[Model] Support Mamba (#6484)

commit 7342a7d7f8 (parent df3dcdf49d)
Author: Tyler Michael Smith
Date: 2024-10-11 11:40:06 -04:00
Committed by: GitHub

29 changed files with 1603 additions and 343 deletions


@@ -24,6 +24,7 @@ class _Backend(enum.Enum):
     FLASHINFER = enum.auto()
     PALLAS = enum.auto()
     IPEX = enum.auto()
+    NO_ATTENTION = enum.auto()


 def backend_name_to_enum(backend_name: str) -> _Backend:
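
For orientation (not part of the diff): a minimal sketch of what the new member means for callers, assuming this file is vllm/attention/selector.py (the import paths below suggest so) and that backend_name_to_enum keeps resolving backend names by _Backend member lookup.

    # Illustrative sketch, not part of the commit: exercising the new member.
    # Assumes backend_name_to_enum() maps the string name to the enum member.
    from vllm.attention.selector import _Backend, backend_name_to_enum

    assert backend_name_to_enum("NO_ATTENTION") is _Backend.NO_ATTENTION
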
@@ -88,13 +89,12 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
 @lru_cache(maxsize=None)
 def get_attn_backend(
-    num_heads: int,
     head_size: int,
-    num_kv_heads: int,
     sliding_window: Optional[int],
     dtype: torch.dtype,
     kv_cache_dtype: Optional[str],
     block_size: int,
+    is_attention_free: bool,
     is_blocksparse: bool = False,
 ) -> Type[AttentionBackend]:
     """Selects which attention backend to use and lazily imports it."""
@@ -105,9 +105,8 @@ def get_attn_backend(
             BlocksparseFlashAttentionBackend)
         return BlocksparseFlashAttentionBackend

-    backend = which_attn_to_use(num_heads, head_size, num_kv_heads,
-                                sliding_window, dtype, kv_cache_dtype,
-                                block_size)
+    backend = which_attn_to_use(head_size, sliding_window, dtype,
+                                kv_cache_dtype, block_size, is_attention_free)
     if backend == _Backend.FLASH_ATTN:
         from vllm.attention.backends.flash_attn import (  # noqa: F401
             FlashAttentionBackend)
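
A side effect of the new parameter, sketched below under the assumption that nothing else about the @lru_cache-decorated selector changed: is_attention_free is part of the memoization key, so an attention-free model and an attention-bearing one with otherwise identical arguments resolve to different cached backend classes. Values here are hypothetical.

    # Hedged sketch with made-up values; the exact class returned for the
    # is_attention_free=False call depends on the platform and environment.
    import torch
    from vllm.attention.selector import get_attn_backend

    with_attn = get_attn_backend(64, None, torch.float16, "auto", 16,
                                 is_attention_free=False)
    no_attn = get_attn_backend(64, None, torch.float16, "auto", 16,
                               is_attention_free=True)
    assert no_attn is not with_attn  # distinct cache entries, distinct backends
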
@@ -146,23 +145,31 @@ def get_attn_backend(
logger.info("Using Pallas backend.")
from vllm.attention.backends.pallas import PallasAttentionBackend
return PallasAttentionBackend
elif backend == _Backend.NO_ATTENTION:
from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionBackend)
return PlaceholderAttentionBackend
else:
raise ValueError("Invalid attention backend.")
def which_attn_to_use(
num_heads: int,
head_size: int,
num_kv_heads: int,
sliding_window: Optional[int],
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
is_attention_free: bool,
) -> _Backend:
"""Returns which flash attention backend to use."""
# Default case.
selected_backend = _Backend.FLASH_ATTN
# If there are no attention layers (e.g. we are running Mamba),
# use the placeholder NO_ATTENTION
if is_attention_free:
return _Backend.NO_ATTENTION
# Check whether a particular choice of backend was
# previously forced.
#
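
Putting the hunks together, a hedged end-to-end sketch (hypothetical argument values; only the behavior visible in this diff is assumed): an attention-free model such as Mamba short-circuits which_attn_to_use to NO_ATTENTION, and get_attn_backend then hands back the placeholder backend.

    # Hedged sketch; argument values are made up.
    import torch
    from vllm.attention.selector import (_Backend, get_attn_backend,
                                         which_attn_to_use)

    assert which_attn_to_use(64, None, torch.float16, "auto", 16,
                             is_attention_free=True) is _Backend.NO_ATTENTION

    backend_cls = get_attn_backend(64, None, torch.float16, "auto", 16,
                                   is_attention_free=True)
    print(backend_cls.__name__)  # PlaceholderAttentionBackend
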