[V1] Enable Mamba2 layers other than MambaMixer2 in the v1 engine (#20660)
Signed-off-by: nopperl <54780682+nopperl@users.noreply.github.com>
This commit is contained in:
@@ -17,6 +17,7 @@ from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
|
||||
update_metadata)
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
@@ -219,7 +220,7 @@ def mamba_v2_sharded_weight_loader(
|
||||
|
||||
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
|
||||
@CustomOp.register("mamba_mixer2")
|
||||
class MambaMixer2(CustomOp):
|
||||
class MambaMixer2(MambaBase, CustomOp):
|
||||
"""
|
||||
Compute ∆, A, B, C, and D the state space parameters and compute
|
||||
the `contextualized_states`. A, D are input independent
|
||||
@@ -231,22 +232,21 @@ class MambaMixer2(CustomOp):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
ssm_state_size: int,
|
||||
conv_kernel_size: int,
|
||||
intermediate_size: int,
|
||||
use_conv_bias: bool,
|
||||
use_bias: bool,
|
||||
n_groups: int = 1,
|
||||
num_heads: int = 128,
|
||||
head_dim: int = 64,
|
||||
rms_norm_eps: float = 1e-5,
|
||||
activation: str = "silu",
|
||||
use_rms_norm: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
chunk_size: int = -1, # the chunk size used by v1
|
||||
self,
|
||||
hidden_size: int,
|
||||
ssm_state_size: int,
|
||||
conv_kernel_size: int,
|
||||
intermediate_size: int,
|
||||
use_conv_bias: bool,
|
||||
use_bias: bool,
|
||||
n_groups: int = 1,
|
||||
num_heads: int = 128,
|
||||
head_dim: int = 64,
|
||||
rms_norm_eps: float = 1e-5,
|
||||
activation: str = "silu",
|
||||
use_rms_norm: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -428,10 +428,7 @@ class MambaMixer2(CustomOp):
|
||||
# of Attention + v0 PP.
|
||||
# The inner tuple is (conv_state, ssm_state)
|
||||
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
|
||||
assert chunk_size != -1, "chunk_size must be set for v1"
|
||||
|
||||
# NOTE: chunk_size may be -1 for models without v1 support
|
||||
self.chunk_size = chunk_size
|
||||
self.prefix = prefix
|
||||
|
||||
def forward_native(
|
||||
|
||||
Reference in New Issue
Block a user