[PluggableLayer][3/N] Apply PluggableLayer to mamba layers. (#33660)

Signed-off-by: whx-sjtu <2952154980@qq.com>
Author: whx
Date: 2026-02-07 21:26:05 +08:00
Committed by: GitHub
Parent: db4ede9743
Commit: ce9b3cd3e9

3 changed files with 13 additions and 31 deletions

@@ -13,7 +13,7 @@ from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_world_size,
 )
 from vllm.forward_context import ForwardContext, get_forward_context
-from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.custom_op import PluggableLayer
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
@@ -41,8 +41,8 @@ from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
 # --8<-- [start:mamba_mixer]
-@CustomOp.register("mamba_mixer")
-class MambaMixer(MambaBase, CustomOp):
+@PluggableLayer.register("mamba_mixer")
+class MambaMixer(MambaBase, PluggableLayer):
     """
     Compute ∆, A, B, C, and D the state space parameters and compute
     the `contextualized_states`. A, D are input independent
@@ -230,10 +230,7 @@ class MambaMixer(MambaBase, CustomOp):
             self.prefix,
         )
-    def forward_native(self, hidden_states: torch.Tensor, output: torch.Tensor):
-        pass
-    def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor):
+    def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor):
         """
         Run the Mamba-1 SSM pipeline.
@@ -528,7 +525,7 @@ def mamba_mixer(
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
-    self.forward_cuda(hidden_states=hidden_states, output=output)
+    self.forward_impl(hidden_states=hidden_states, output=output)
 def mamba_mixer_fake(
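
For context (not part of this diff): a minimal sketch of how a layer following this pattern could be defined, assuming PluggableLayer keeps the register() decorator and single forward_impl() entry point shown above. MyMambaMixer and the "my_mamba_mixer" key are hypothetical names used only for illustration.

# Illustrative sketch only, not from this commit. Assumes PluggableLayer
# exposes the same register() classmethod decorator used in the hunks above
# and that subclasses implement one forward_impl() method.
import torch

from vllm.model_executor.custom_op import PluggableLayer


@PluggableLayer.register("my_mamba_mixer")  # hypothetical registry key
class MyMambaMixer(PluggableLayer):
    def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor):
        # Placeholder body; a real mixer would run the SSM pipeline here
        # and write the result into `output` in place.
        output.copy_(hidden_states)

With the last hunk routing the module-level mamba_mixer op through forward_impl, a variant presumably only needs to override that one method rather than separate forward_native/forward_cuda paths.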