[PluggableLayer][3/N] Apply PluggableLayer to mamba layers. (#33660)

Signed-off-by: whx-sjtu <2952154980@qq.com>
Author: whx
Date: 2026-02-07 21:26:05 +08:00
Committed by: GitHub
Parent: db4ede9743
Commit: ce9b3cd3e9
3 changed files with 13 additions and 31 deletions


@@ -14,7 +14,7 @@ from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import ForwardContext, get_forward_context
-from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -107,8 +107,8 @@ def is_mamba(config: Plamo2Config, i: int) -> bool:
# vllm.model_executor.layers.mamba.mamba_mixer2.MambaMixer2
# transformers.models.mamba.modeling_mamba.MambaMixer
# --8<-- [start:plamo2_mamba_mixer]
-@CustomOp.register("plamo2_mamba_mixer")
-class Plamo2MambaMixer(MambaBase, CustomOp):
+@PluggableLayer.register("plamo2_mamba_mixer")
+class Plamo2MambaMixer(MambaBase, PluggableLayer):
# --8<-- [end:plamo2_mamba_mixer]
def __init__(self, vllm_config: VllmConfig, *, prefix: str = "", **kwargs) -> None:
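For readers unfamiliar with the registration pattern in the hunk above, here is a minimal sketch of how a decorator-based layer registry of this kind can work. The names below (`PluggableLayerSketch`, `DemoMixer`, `get`) are illustrative stand-ins, not the actual vLLM `PluggableLayer` API; they only demonstrate the register-by-name idea that lets a plugin substitute its own implementation for a named layer.

```python
from typing import Callable


class PluggableLayerSketch:
    """Illustrative registry: maps a layer name to the class implementing it."""

    _registry: dict[str, type] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[type], type]:
        # Used as a class decorator; records the class under `name` so it
        # can later be looked up (or overridden) by name.
        def wrap(layer_cls: type) -> type:
            cls._registry[name] = layer_cls
            return layer_cls

        return wrap

    @classmethod
    def get(cls, name: str) -> type:
        return cls._registry[name]


@PluggableLayerSketch.register("plamo2_mamba_mixer")
class DemoMixer(PluggableLayerSketch):
    pass


assert PluggableLayerSketch.get("plamo2_mamba_mixer") is DemoMixer
```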
@@ -233,14 +233,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
        dt = self.dt_proj(time_step)
        return B, C, dt

-    def forward_native(
-        self,
-        hidden_states: torch.Tensor,
-        output: torch.Tensor,
-        **kwargs,
-    ):
-        pass
-
    def forward(
        self,
        hidden_states: torch.Tensor,
@@ -253,7 +245,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
            self.prefix,
        )

-    def forward_cuda(
+    def forward_impl(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
@@ -494,7 +486,7 @@ def plamo2_mamba_mixer(
) -> None:
    forward_context: ForwardContext = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]
-    self.forward_cuda(hidden_states=hidden_states, output=output)
+    self.forward_impl(hidden_states=hidden_states, output=output)


def plamo2_mamba_mixer_fake(
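The last hunk shows why the `forward_cuda` to `forward_impl` rename matters: `forward()` dispatches through an op keyed by `self.prefix`, and the op implementation looks the layer back up via `forward_context.no_compile_layers[layer_name]` before calling `forward_impl`. Below is a simplified, self-contained sketch of that indirection, assuming a plain dict in place of vLLM's `ForwardContext`; all names here are illustrative, not the real API.

```python
import torch

# Stand-in for vLLM's forward_context.no_compile_layers table.
_NO_COMPILE_LAYERS: dict[str, "MixerSketch"] = {}


class MixerSketch:
    def __init__(self, prefix: str) -> None:
        self.prefix = prefix
        _NO_COMPILE_LAYERS[prefix] = self

    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor) -> None:
        # In vLLM this call goes through a registered torch custom op that
        # receives the layer name, so torch.compile treats it as opaque;
        # here we call the free function directly.
        mixer_op(hidden_states, output, self.prefix)

    def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor) -> None:
        # Placeholder for the real mamba mixer computation, which writes
        # its result into the preallocated `output` buffer.
        output.copy_(hidden_states)


def mixer_op(
    hidden_states: torch.Tensor, output: torch.Tensor, layer_name: str
) -> None:
    # Mirrors plamo2_mamba_mixer in the diff: recover the layer object by
    # name, then dispatch to its forward_impl.
    layer = _NO_COMPILE_LAYERS[layer_name]
    layer.forward_impl(hidden_states=hidden_states, output=output)


layer = MixerSketch(prefix="model.layers.0.mixer")
x = torch.randn(2, 8)
out = torch.empty_like(x)
layer.forward(x, out)
assert torch.equal(out, x)
```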