[Model][Mamba] Add selector for mamba attention backend and make it pluggable for other devices (#26487)
Signed-off-by: shen-shanshan <467638484@qq.com>
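
The change, as a minimal sketch: every Mamba-style layer exposes a mamba_type string, and the new get_mamba_attn_backend helper (imported from vllm.attention.selector in the first hunk) resolves that string to an attention backend class. Each layer can therefore drop its own get_attn_backend override, and other devices get a single point at which to plug in their backends. The type strings and backend classes below are the ones visible in this diff; the if-chain dispatch and the error handling are illustrative assumptions, not the selector's actual implementation.

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from vllm.attention.backends.abstract import AttentionBackend


    def get_mamba_attn_backend(mamba_type: str) -> type["AttentionBackend"]:
        # Sketch only: the dispatch shape is an assumption; the type strings and
        # backend classes are the ones named in this diff.
        if mamba_type == "linear_attention":
            from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend
            return LinearAttentionBackend
        if mamba_type == "mamba1":
            from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
            return Mamba1AttentionBackend
        if mamba_type == "mamba2":
            from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
            return Mamba2AttentionBackend
        if mamba_type == "short_conv":
            from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend
            return ShortConvAttentionBackend
        raise ValueError(f"Unsupported mamba_type: {mamba_type!r}")
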
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING

import torch

from vllm.attention.selector import get_mamba_attn_backend
from vllm.config import VllmConfig
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec

@@ -38,11 +39,6 @@ class MambaBase(AttentionLayerBase):
    def mamba_type(self) -> str:
        pass

    @abstractmethod
    def get_attn_backend(self) -> type["AttentionBackend"]:
        """Get the attention backend class for this Mamba layer."""
        pass

    @abstractmethod
    def get_state_dtype(self) -> tuple[torch.dtype, ...]:
        pass

@@ -69,3 +65,7 @@ class MambaBase(AttentionLayerBase):
                else 0
            ),
        )

    def get_attn_backend(self) -> type["AttentionBackend"]:
        """Get the attention backend class for this Mamba layer."""
        return get_mamba_attn_backend(self.mamba_type)

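The hunk above is the heart of the change: MambaBase.get_attn_backend now delegates to the selector, which is why the per-layer overrides in the hunks below can be deleted. A self-contained toy version of the resulting pattern (plain Python, not vLLM code; the registry decorator and all Toy* names are hypothetical) shows why supporting a new device only has to teach the selector, not every layer:

    # Toy illustration of the selector pattern; names here are hypothetical, not vLLM APIs.
    _MAMBA_BACKENDS: dict[str, type] = {}


    def register_mamba_backend(mamba_type: str):
        def decorator(cls: type) -> type:
            _MAMBA_BACKENDS[mamba_type] = cls
            return cls
        return decorator


    class ToyMambaBase:
        mamba_type: str = ""

        def get_attn_backend(self) -> type:
            # Single resolution point, analogous to get_mamba_attn_backend(self.mamba_type).
            return _MAMBA_BACKENDS[self.mamba_type]


    @register_mamba_backend("mamba2")
    class ToyMamba2Backend:
        pass


    class ToyMamba2Layer(ToyMambaBase):
        mamba_type = "mamba2"


    assert ToyMamba2Layer().get_attn_backend() is ToyMamba2Backend
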
@@ -2,12 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend

from typing import TYPE_CHECKING

import torch
import torch.nn.functional as F

@@ -37,9 +31,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend


class MiniMaxText01RMSNormTP(CustomOp):
    name = "MiniMaxText01RMSNormTP"

@@ -123,11 +114,6 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
    def mamba_type(self) -> str:
        return "linear_attention"

    def get_attn_backend(self) -> type["AttentionBackend"]:
        from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend

        return LinearAttentionBackend

    def get_state_dtype(self) -> tuple[torch.dtype]:
        assert self.model_config is not None
        assert self.cache_config is not None

@@ -1,10 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import TYPE_CHECKING, NamedTuple

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend
from typing import NamedTuple

import torch
from torch import nn

@@ -452,11 +449,6 @@ class MambaMixer(MambaBase, CustomOp):
    def mamba_type(self) -> str:
        return "mamba1"

    def get_attn_backend(self) -> type["AttentionBackend"]:
        from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend

        return Mamba1AttentionBackend

    def _time_proj_bias(self) -> torch.Tensor | None:
        if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
            return self.dt_proj.bias.float()

@@ -1,10 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend

import torch
from torch import nn

@@ -908,11 +904,6 @@ class MambaMixer2(MambaBase, CustomOp):
    def mamba_type(self) -> str:
        return "mamba2"

    def get_attn_backend(self) -> type["AttentionBackend"]:
        from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend

        return Mamba2AttentionBackend


def mamba_mixer2(
    projected_states: torch.Tensor,

@@ -1,10 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend

import torch

@@ -232,11 +228,6 @@ class ShortConv(MambaBase, CustomOp):
    def mamba_type(self) -> str:
        return "short_conv"

    def get_attn_backend(self) -> type["AttentionBackend"]:
        from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend

        return ShortConvAttentionBackend


def short_conv(
    hidden_states: torch.Tensor,