[Misc] Enhance attention selector (#4751)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import enum
|
||||
from functools import lru_cache
|
||||
from typing import Type
|
||||
from typing import Optional, Type
|
||||
|
||||
import torch
|
||||
|
||||
@@ -21,8 +21,18 @@ class _Backend(enum.Enum):
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
|
||||
backend = _which_attn_to_use(dtype)
|
||||
def get_attn_backend(
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
num_kv_heads: int,
|
||||
sliding_window: Optional[int],
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
) -> Type[AttentionBackend]:
|
||||
backend = _which_attn_to_use(num_heads, head_size, num_kv_heads,
|
||||
sliding_window, dtype, kv_cache_dtype,
|
||||
block_size)
|
||||
if backend == _Backend.FLASH_ATTN:
|
||||
logger.info("Using FlashAttention-2 backend.")
|
||||
from vllm.attention.backends.flash_attn import ( # noqa: F401
|
||||
@@ -44,14 +54,22 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
|
||||
return TorchSDPABackend
|
||||
elif backend == _Backend.FLASHINFER:
|
||||
logger.info("Using Flashinfer backend.")
|
||||
logger.warning("Eager mode is enforced for the Flashinfer backend. ")
|
||||
logger.warning("Eager mode is enforced for the Flashinfer backend.")
|
||||
from vllm.attention.backends.flashinfer import FlashInferBackend
|
||||
return FlashInferBackend
|
||||
else:
|
||||
raise ValueError("Invalid attention backend.")
|
||||
|
||||
|
||||
def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
|
||||
def _which_attn_to_use(
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
num_kv_heads: int,
|
||||
sliding_window: Optional[int],
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
) -> _Backend:
|
||||
"""Returns which flash attention backend to use."""
|
||||
if is_cpu():
|
||||
return _Backend.TORCH_SDPA
|
||||
|
||||
Reference in New Issue
Block a user