[V1] Enable Mamba2 layers other than MambaMixer2 in the v1 engine (#20660)

Signed-off-by: nopperl <54780682+nopperl@users.noreply.github.com>
nopperl authored 2025-07-11 14:53:31 +09:00, committed by GitHub
parent 31d5c1797f
commit 5d09152ff1
11 changed files with 68 additions and 45 deletions
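The commit introduces a common base class for state-space layers: the diff below imports MambaBase from the new vllm.model_executor.layers.mamba.abstract module and makes MambaMixer2 derive from it, so the v1 engine can manage state for any Mamba-style layer instead of special-casing MambaMixer2. The abstract module itself is not shown in this excerpt; a minimal sketch of the interface it implies (the method name get_state_shape and the kv_cache annotation are assumptions inferred from the context lines in the last hunk, not the commit's actual code) might look like:

# Hypothetical sketch of vllm/model_executor/layers/mamba/abstract.py;
# names other than MambaBase are assumptions, not the commit's code.
from abc import ABC, abstractmethod
from collections.abc import Iterable

import torch


class MambaBase(ABC):
    """Shared interface for Mamba-style layers supported by the v1 engine."""

    # Per-layer state buffers, e.g. a (conv_state, ssm_state) tuple; kept
    # as a list to match the Attention + v0 pipeline-parallel interface.
    kv_cache: list[Iterable[torch.Tensor]]

    @abstractmethod
    def get_state_shape(self) -> Iterable[tuple[int, ...]]:
        """Return the shape of each state tensor the layer needs."""
        ...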

vllm/model_executor/layers/mamba/mamba_mixer2.py

@@ -17,6 +17,7 @@ from vllm.forward_context import get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
                                                               update_metadata)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
@@ -219,7 +220,7 @@ def mamba_v2_sharded_weight_loader(
 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
 @CustomOp.register("mamba_mixer2")
-class MambaMixer2(CustomOp):
+class MambaMixer2(MambaBase, CustomOp):
     """
     Compute ∆, A, B, C, and D the state space parameters and compute
     the `contextualized_states`. A, D are input independent
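With MambaBase among the bases, engine code can discover every compatible layer by type rather than matching MambaMixer2 specifically. A hypothetical illustration (find_mamba_layers is not from this commit; it reuses the MambaBase sketch above and standard torch.nn.Module traversal):

import torch.nn as nn


def find_mamba_layers(model: nn.Module) -> list[MambaBase]:
    # Picks up MambaMixer2 and any other layer deriving from MambaBase.
    return [m for m in model.modules() if isinstance(m, MambaBase)]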
@@ -231,22 +232,21 @@ class MambaMixer2(CustomOp):
     """
 
     def __init__(
-        self,
-        hidden_size: int,
-        ssm_state_size: int,
-        conv_kernel_size: int,
-        intermediate_size: int,
-        use_conv_bias: bool,
-        use_bias: bool,
-        n_groups: int = 1,
-        num_heads: int = 128,
-        head_dim: int = 64,
-        rms_norm_eps: float = 1e-5,
-        activation: str = "silu",
-        use_rms_norm: bool = True,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-        chunk_size: int = -1,  # the chunk size used by v1
+        self,
+        hidden_size: int,
+        ssm_state_size: int,
+        conv_kernel_size: int,
+        intermediate_size: int,
+        use_conv_bias: bool,
+        use_bias: bool,
+        n_groups: int = 1,
+        num_heads: int = 128,
+        head_dim: int = 64,
+        rms_norm_eps: float = 1e-5,
+        activation: str = "silu",
+        use_rms_norm: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
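As rendered above, the new signature drops the trailing chunk_size parameter, so a model wiring up the mixer would construct it without one. A hypothetical call site under that assumption (build_mixer and the config attribute names are illustrative, not taken from this commit):

def build_mixer(config, prefix: str) -> "MambaMixer2":
    # All config attribute names here are illustrative.
    return MambaMixer2(
        hidden_size=config.hidden_size,
        ssm_state_size=config.state_size,
        conv_kernel_size=config.conv_kernel,
        intermediate_size=config.intermediate_size,
        use_conv_bias=config.use_conv_bias,
        use_bias=config.use_bias,
        n_groups=config.n_groups,
        num_heads=config.num_heads,
        head_dim=config.head_dim,
        prefix=f"{prefix}.mixer",
    )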
@@ -428,10 +428,7 @@ class MambaMixer2(CustomOp):
         # of Attention + v0 PP.
         # The inner tuple is (conv_state, ssm_state)
         self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
-        assert chunk_size != -1, "chunk_size must be set for v1"
+        # NOTE: chunk_size may be -1 for models without v1 support
         self.chunk_size = chunk_size
         self.prefix = prefix
 
     def forward_native(
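The last hunk relaxes the construction-time invariant: previously v1 asserted chunk_size != -1 in __init__, whereas now a -1 chunk size is tolerated for models without v1 support. If the value still needs validating before use, a hypothetical guard (not shown in this excerpt) could look like:

def require_chunk_size(layer: "MambaMixer2") -> int:
    # Hypothetical check; the commit may enforce this elsewhere (e.g. in
    # the v1 engine) rather than at layer construction time.
    assert layer.chunk_size != -1, "chunk_size must be set for v1"
    return layer.chunk_size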