[Model][Mamba] Add selector for mamba attention backend and make it pluggable for other device (#26487)

Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
Shanshan Shen
2025-11-20 00:24:55 +08:00
committed by GitHub
parent 48fc8b1e59
commit d44e9df7d4
12 changed files with 144 additions and 85 deletions

View File

@@ -12,7 +12,11 @@ import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.backends.registry import (
MAMBA_TYPE_TO_BACKEND_MAP,
AttentionBackendEnum,
MambaAttentionBackendEnum,
)
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.utils import STR_BACKEND_ENV_VAR
@@ -197,6 +201,33 @@ def _cached_get_attn_backend(
return backend
def get_mamba_attn_backend(
    mamba_type: str,
) -> type[AttentionBackend]:
    """Return the attention backend class registered for ``mamba_type``.

    Thin public entry point that delegates to the cached resolver, so
    repeated lookups for the same mamba layer type are served from cache
    and the backend module is only imported on first use.
    """
    backend_cls = _cached_get_mamba_attn_backend(mamba_type)
    return backend_cls
@cache
def _cached_get_mamba_attn_backend(
    mamba_type: str,
) -> type[AttentionBackend]:
    """Resolve and lazily import the attention backend class for a mamba type.

    Results are memoized via ``functools.cache``, so each distinct
    ``mamba_type`` is resolved (and its backend module imported) at most once.

    Args:
        mamba_type: Mamba layer type key; must be a non-empty string with an
            entry in ``MAMBA_TYPE_TO_BACKEND_MAP``.

    Returns:
        The ``AttentionBackend`` subclass registered for ``mamba_type``.

    Raises:
        ValueError: If ``mamba_type`` is empty or not a string, if it has no
            mapped backend name, or if the mapped name is not a member of
            ``MambaAttentionBackendEnum``.
    """
    # Explicit validation instead of `assert`, which is stripped under -O.
    if not mamba_type or not isinstance(mamba_type, str):
        raise ValueError(
            f"mamba_type must be a non-empty string, got {mamba_type!r}"
        )
    # BUG FIX: the original wrapped both lookups in one try block and
    # formatted `backend_name` in the error message; when the map lookup
    # itself raised KeyError, `backend_name` was unbound and the handler
    # crashed with UnboundLocalError instead of the intended ValueError.
    try:
        backend_name = MAMBA_TYPE_TO_BACKEND_MAP[mamba_type]
    except KeyError as e:
        raise ValueError(
            f"Invalid mamba type: '{mamba_type}'. Valid "
            f"types are: {list(MAMBA_TYPE_TO_BACKEND_MAP.keys())}"
        ) from e
    try:
        selected_backend = MambaAttentionBackendEnum[backend_name]
    except KeyError as e:
        raise ValueError(
            f"Invalid mamba attention backend type: '{backend_name}'. Valid "
            f"backends are: {list(MambaAttentionBackendEnum.__members__.keys())}"
        ) from e
    mamba_attn_backend = selected_backend.get_class()
    return mamba_attn_backend
@contextmanager
def global_force_attn_backend_context_manager(
attn_backend: AttentionBackendEnum,