[PluggableLayer][3/N] Apply PluggableLayer to mamba layers. (#33660)
Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -14,7 +14,7 @@ from vllm.distributed import (
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.custom_op import CustomOp, PluggableLayer
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
@@ -219,8 +219,8 @@ def mamba_v2_sharded_weight_loader(
|
||||
|
||||
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
|
||||
# --8<-- [start:mamba_mixer2]
|
||||
@CustomOp.register("mamba_mixer2")
|
||||
class MambaMixer2(MambaBase, CustomOp):
|
||||
@PluggableLayer.register("mamba_mixer2")
|
||||
class MambaMixer2(MambaBase, PluggableLayer):
|
||||
"""
|
||||
Compute ∆, A, B, C, and D the state space parameters and compute
|
||||
the `contextualized_states`. A, D are input independent
|
||||
@@ -472,13 +472,6 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
# Check if running on Blackwell (SM100+) for kernel tuning
|
||||
self.is_blackwell = current_platform.is_device_capability_family(100)
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
mup_vector: torch.Tensor | None = None,
|
||||
):
|
||||
pass
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user