diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index adc643c38..e2575a2b4 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -13,7 +13,7 @@ from vllm.distributed.parallel_state import ( get_tensor_model_parallel_world_size, ) from vllm.forward_context import ForwardContext, get_forward_context -from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -41,8 +41,8 @@ from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer # --8<-- [start:mamba_mixer] -@CustomOp.register("mamba_mixer") -class MambaMixer(MambaBase, CustomOp): +@PluggableLayer.register("mamba_mixer") +class MambaMixer(MambaBase, PluggableLayer): """ Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. A, D are input independent @@ -230,10 +230,7 @@ class MambaMixer(MambaBase, CustomOp): self.prefix, ) - def forward_native(self, hidden_states: torch.Tensor, output: torch.Tensor): - pass - - def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor): + def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor): """ Run the Mamba-1 SSM pipeline. @@ -528,7 +525,7 @@ def mamba_mixer( ) -> None: forward_context: ForwardContext = get_forward_context() self = forward_context.no_compile_layers[layer_name] - self.forward_cuda(hidden_states=hidden_states, output=output) + self.forward_impl(hidden_states=hidden_states, output=output) def mamba_mixer_fake( diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index f602d9b62..c325a0381 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -14,7 +14,7 @@ from vllm.distributed import ( tensor_model_parallel_all_reduce, ) from vllm.forward_context import ForwardContext, get_forward_context -from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.custom_op import CustomOp, PluggableLayer from vllm.model_executor.layers.linear import ( ColumnParallelLinear, RowParallelLinear, @@ -219,8 +219,8 @@ def mamba_v2_sharded_weight_loader( # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer # --8<-- [start:mamba_mixer2] -@CustomOp.register("mamba_mixer2") -class MambaMixer2(MambaBase, CustomOp): +@PluggableLayer.register("mamba_mixer2") +class MambaMixer2(MambaBase, PluggableLayer): """ Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. A, D are input independent @@ -472,13 +472,6 @@ class MambaMixer2(MambaBase, CustomOp): # Check if running on Blackwell (SM100+) for kernel tuning self.is_blackwell = current_platform.is_device_capability_family(100) - def forward_native( - self, - hidden_states: torch.Tensor, - mup_vector: torch.Tensor | None = None, - ): - pass - def forward( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 2bc89cc23..68f0b9550 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -14,7 +14,7 @@ from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import ForwardContext, get_forward_context -from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.layernorm import RMSNorm @@ -107,8 +107,8 @@ def is_mamba(config: Plamo2Config, i: int) -> bool: # vllm.model_executor.layers.mamba.mamba_mixer2.MambaMixer2 # transformers.models.mamba.modeling_mamba.MambaMixer # --8<-- [start:plamo2_mamba_mixer] -@CustomOp.register("plamo2_mamba_mixer") -class Plamo2MambaMixer(MambaBase, CustomOp): +@PluggableLayer.register("plamo2_mamba_mixer") +class Plamo2MambaMixer(MambaBase, PluggableLayer): # --8<-- [end:plamo2_mamba_mixer] def __init__(self, vllm_config: VllmConfig, *, prefix: str = "", **kwargs) -> None: @@ -233,14 +233,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp): dt = self.dt_proj(time_step) return B, C, dt - def forward_native( - self, - hidden_states: torch.Tensor, - output: torch.Tensor, - **kwargs, - ): - pass - def forward( self, hidden_states: torch.Tensor, @@ -253,7 +245,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp): self.prefix, ) - def forward_cuda( + def forward_impl( self, hidden_states: torch.Tensor, output: torch.Tensor, @@ -494,7 +486,7 @@ def plamo2_mamba_mixer( ) -> None: forward_context: ForwardContext = get_forward_context() self = forward_context.no_compile_layers[layer_name] - self.forward_cuda(hidden_states=hidden_states, output=output) + self.forward_impl(hidden_states=hidden_states, output=output) def plamo2_mamba_mixer_fake(