[Model][Mamba] Add selector for mamba attention backend and make it pluggable for other device (#26487)
Signed-off-by: shen-shanshan <467638484@qq.com>
@@ -12,7 +12,11 @@ import torch
 
 import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.attention.backends.registry import AttentionBackendEnum
+from vllm.attention.backends.registry import (
+    MAMBA_TYPE_TO_BACKEND_MAP,
+    AttentionBackendEnum,
+    MambaAttentionBackendEnum,
+)
 from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
 from vllm.utils import STR_BACKEND_ENV_VAR
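The new imports pull the mamba-specific pieces out of the backend registry. For orientation, here is a hedged, self-contained sketch of the pattern those names follow as far as this diff shows it: a plain dict from mamba type string to enum member name, and an enum whose members resolve a backend class lazily via get_class(). The concrete entries, class names, and module paths below are illustrative assumptions, not the real contents of vllm/attention/backends/registry.py.

# Hypothetical illustration only; MyMamba2Backend and my_plugin.backends are made up.
from enum import Enum
from importlib import import_module


class IllustrativeMambaBackendEnum(Enum):
    # Each member's value points at a backend class that is imported lazily,
    # which is what lets out-of-tree devices plug in their own implementation.
    MAMBA2 = "my_plugin.backends:MyMamba2Backend"

    def get_class(self):
        module_path, _, class_name = self.value.partition(":")
        return getattr(import_module(module_path), class_name)


# Mirrors the role of MAMBA_TYPE_TO_BACKEND_MAP: mamba type string -> member name.
ILLUSTRATIVE_TYPE_TO_BACKEND = {"mamba2": "MAMBA2"}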
@@ -197,6 +201,33 @@ def _cached_get_attn_backend(
     return backend
 
 
+def get_mamba_attn_backend(
+    mamba_type: str,
+) -> type[AttentionBackend]:
+    """Select which mamba attention backend to use and lazily import it."""
+    return _cached_get_mamba_attn_backend(mamba_type)
+
+
+@cache
+def _cached_get_mamba_attn_backend(
+    mamba_type: str,
+) -> type[AttentionBackend]:
+    assert mamba_type and isinstance(mamba_type, str)
+
+    selected_backend = None
+    try:
+        backend_name = MAMBA_TYPE_TO_BACKEND_MAP[mamba_type]
+        selected_backend = MambaAttentionBackendEnum[backend_name]
+    except KeyError as e:
+        raise ValueError(
f"Invalid mamba attention backend type: '{backend_name}'. Valid "
+            f"backends are: {list(MambaAttentionBackendEnum.__members__.keys())}"
+        ) from e
+
+    mamba_attn_backend = selected_backend.get_class()
+    return mamba_attn_backend
+
+
 @contextmanager
 def global_force_attn_backend_context_manager(
     attn_backend: AttentionBackendEnum,
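For completeness, a minimal usage sketch of the new selector, assuming it lives in vllm.attention.selector alongside _cached_get_attn_backend; the "mamba2" type string is an assumption for illustration, and valid keys are whatever MAMBA_TYPE_TO_BACKEND_MAP defines. Because of the @cache decorator, repeated calls with the same mamba_type return the already-resolved backend class without re-running the lookup.

# Hedged usage sketch (not part of this diff); module path and type string are assumptions.
from vllm.attention.selector import get_mamba_attn_backend

backend_cls = get_mamba_attn_backend("mamba2")  # lazily imported, then cached
print(backend_cls.get_name())  # get_name() assumed from the AttentionBackend interface

# An unknown type surfaces the ValueError raised in the selector:
try:
    get_mamba_attn_backend("not-a-mamba-type")
except ValueError as err:
    print(err)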