[PluggableLayer][3/N] Apply PluggableLayer to mamba layers. (#33660)

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
whx
2026-02-07 21:26:05 +08:00
committed by GitHub
parent db4ede9743
commit ce9b3cd3e9
3 changed files with 13 additions and 31 deletions

View File

@@ -14,7 +14,7 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.custom_op import CustomOp, PluggableLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
RowParallelLinear,
@@ -219,8 +219,8 @@ def mamba_v2_sharded_weight_loader(
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
# --8<-- [start:mamba_mixer2]
@CustomOp.register("mamba_mixer2")
class MambaMixer2(MambaBase, CustomOp):
@PluggableLayer.register("mamba_mixer2")
class MambaMixer2(MambaBase, PluggableLayer):
"""
Compute ∆, A, B, C, and D, the state space parameters, and compute
the `contextualized_states`. A, D are input independent
@@ -472,13 +472,6 @@ class MambaMixer2(MambaBase, CustomOp):
# Check if running on Blackwell (SM100+) for kernel tuning
self.is_blackwell = current_platform.is_device_capability_family(100)
def forward_native(
self,
hidden_states: torch.Tensor,
mup_vector: torch.Tensor | None = None,
):
pass
def forward(
self,
hidden_states: torch.Tensor,