17 lines
652 B
Python
17 lines
652 B
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from vllm.attention.backends.abstract import AttentionBackend
|
|
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
|
|
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
|
|
|
|
|
|
def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]:
|
|
if mamba_type == "mamba1":
|
|
return Mamba1AttentionBackend
|
|
|
|
if mamba_type == "mamba2":
|
|
return Mamba2AttentionBackend
|
|
|
|
raise NotImplementedError(f"Mamba Attention type {mamba_type} is not "
|
|
"supported yet.")
|