[Model] Support Mamba (#6484)
commit 7342a7d7f8 (parent df3dcdf49d), committed by GitHub
@@ -24,6 +24,7 @@ class _Backend(enum.Enum):
     FLASHINFER = enum.auto()
     PALLAS = enum.auto()
     IPEX = enum.auto()
+    NO_ATTENTION = enum.auto()


 def backend_name_to_enum(backend_name: str) -> _Backend:
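For context on what the new enum member enables, here is a minimal, self-contained sketch (not part of this commit) of resolving a backend name to the extended _Backend enum. The body of backend_name_to_enum below is hypothetical and only illustrates the lookup; the real helper lives in vLLM's attention selector module.

import enum

class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    FLASHINFER = enum.auto()
    PALLAS = enum.auto()
    IPEX = enum.auto()
    NO_ATTENTION = enum.auto()  # member added by this commit

def backend_name_to_enum(backend_name: str) -> _Backend:
    # Hypothetical lookup for illustration only.
    return _Backend[backend_name]

assert backend_name_to_enum("NO_ATTENTION") is _Backend.NO_ATTENTION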
@@ -88,13 +89,12 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:

 @lru_cache(maxsize=None)
 def get_attn_backend(
-    num_heads: int,
     head_size: int,
-    num_kv_heads: int,
     sliding_window: Optional[int],
     dtype: torch.dtype,
     kv_cache_dtype: Optional[str],
     block_size: int,
+    is_attention_free: bool,
     is_blocksparse: bool = False,
 ) -> Type[AttentionBackend]:
     """Selects which attention backend to use and lazily imports it."""
@@ -105,9 +105,8 @@ def get_attn_backend(
             BlocksparseFlashAttentionBackend)
         return BlocksparseFlashAttentionBackend

-    backend = which_attn_to_use(num_heads, head_size, num_kv_heads,
-                                sliding_window, dtype, kv_cache_dtype,
-                                block_size)
+    backend = which_attn_to_use(head_size, sliding_window, dtype,
+                                kv_cache_dtype, block_size, is_attention_free)
     if backend == _Backend.FLASH_ATTN:
         from vllm.attention.backends.flash_attn import (  # noqa: F401
             FlashAttentionBackend)
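A hypothetical call site for the updated signature (a sketch, not code from this commit), assuming the selector is importable as vllm.attention.selector and using made-up example values. The point is that attention-free models such as Mamba now pass is_attention_free=True and, per the new branch below, get routed to the placeholder backend.

import torch
from vllm.attention.selector import get_attn_backend  # assumed module path

# Example values are illustrative, not taken from the diff.
attn_backend_cls = get_attn_backend(
    head_size=64,
    sliding_window=None,
    dtype=torch.float16,
    kv_cache_dtype="auto",
    block_size=16,
    is_attention_free=True,  # new flag introduced by this commit
)
print(attn_backend_cls.__name__)  # expected: PlaceholderAttentionBackend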
@@ -146,23 +145,31 @@ def get_attn_backend(
         logger.info("Using Pallas backend.")
         from vllm.attention.backends.pallas import PallasAttentionBackend
         return PallasAttentionBackend
+    elif backend == _Backend.NO_ATTENTION:
+        from vllm.attention.backends.placeholder_attn import (
+            PlaceholderAttentionBackend)
+        return PlaceholderAttentionBackend
     else:
         raise ValueError("Invalid attention backend.")


 def which_attn_to_use(
-    num_heads: int,
     head_size: int,
-    num_kv_heads: int,
     sliding_window: Optional[int],
     dtype: torch.dtype,
     kv_cache_dtype: Optional[str],
     block_size: int,
+    is_attention_free: bool,
 ) -> _Backend:
     """Returns which flash attention backend to use."""
     # Default case.
     selected_backend = _Backend.FLASH_ATTN

+    # If there are no attention layers (e.g. we are running Mamba),
+    # use the placeholder NO_ATTENTION
+    if is_attention_free:
+        return _Backend.NO_ATTENTION
+
     # Check whether a particular choice of backend was
     # previously forced.
     #
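To make the control flow of the last hunk concrete, here is a standalone sketch (a trimmed-down _Backend enum and the hypothetical name which_attn_to_use_sketch; this is not vLLM's actual code) showing how the new is_attention_free flag short-circuits backend selection before any forced-backend, hardware, or dtype checks run.

import enum

class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    NO_ATTENTION = enum.auto()

def which_attn_to_use_sketch(is_attention_free: bool) -> _Backend:
    # Default case, as in the diff.
    selected_backend = _Backend.FLASH_ATTN

    # If there are no attention layers (e.g. we are running Mamba),
    # use the placeholder NO_ATTENTION backend.
    if is_attention_free:
        return _Backend.NO_ATTENTION

    # ... the real function goes on to check forced backends, device
    # support, dtype, and kv_cache_dtype before settling on a backend.
    return selected_backend

assert which_attn_to_use_sketch(True) is _Backend.NO_ATTENTION
assert which_attn_to_use_sketch(False) is _Backend.FLASH_ATTN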