[Attention] Unify mamba and attention backend selection (#23171)

Signed-off-by: Ayush Satyam <ayushsatyam146@gmail.com>
This commit is contained in:
Ayush Satyam
2025-08-25 14:39:36 +05:30
committed by GitHub
parent d0a4a3f645
commit 5c4b6e66fe
11 changed files with 186 additions and 72 deletions

View File

@@ -1,12 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from abc import abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING
import torch
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
class MambaBase(ABC):
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
class MambaBase(AttentionLayerBase):
"""
Base class for Mamba-like layers which support the v1 engine.
Inherit from this class if you implement a custom layer.
@@ -32,3 +38,8 @@ class MambaBase(ABC):
@abstractmethod
def mamba_type(self) -> str:
pass
@abstractmethod
def get_attn_backend(self) -> type["AttentionBackend"]:
"""Get the attention backend class for this Mamba layer."""
pass

View File

@@ -1,7 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import NamedTuple, Optional
from typing import TYPE_CHECKING, NamedTuple, Optional
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
import torch
from torch import nn
@@ -404,6 +407,11 @@ class MambaMixer(MambaBase, CustomOp):
def mamba_type(self) -> str:
return "mamba1"
def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.mamba1_attn import (
Mamba1AttentionBackend)
return Mamba1AttentionBackend
def _time_proj_bias(self) -> Optional[torch.Tensor]:
if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
return self.dt_proj.bias.float()

View File

@@ -1,7 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Union
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
import torch
from torch import nn
@@ -758,6 +761,11 @@ class MambaMixer2(MambaBase, CustomOp):
def mamba_type(self) -> str:
return "mamba2"
def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.mamba2_attn import (
Mamba2AttentionBackend)
return Mamba2AttentionBackend
def mamba_mixer2(
hidden_states: torch.Tensor,

View File

@@ -1,7 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
import torch
@@ -232,6 +235,11 @@ class ShortConv(MambaBase, CustomOp):
def mamba_type(self) -> str:
return "short_conv"
def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.short_conv_attn import (
ShortConvAttentionBackend)
return ShortConvAttentionBackend
def short_conv(
hidden_states: torch.Tensor,