[PluggableLayer][3/N] Apply PluggableLayer to mamba layers. (#33660)
Signed-off-by: whx-sjtu <2952154980@qq.com>
@@ -13,7 +13,7 @@ from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_world_size,
 )
 from vllm.forward_context import ForwardContext, get_forward_context
-from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.custom_op import PluggableLayer
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
@@ -41,8 +41,8 @@ from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata

 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
 # --8<-- [start:mamba_mixer]
-@CustomOp.register("mamba_mixer")
-class MambaMixer(MambaBase, CustomOp):
+@PluggableLayer.register("mamba_mixer")
+class MambaMixer(MambaBase, PluggableLayer):
     """
     Compute ∆, A, B, C, and D the state space parameters and compute
     the `contextualized_states`. A, D are input independent
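The hunk above re-registers MambaMixer through PluggableLayer instead of CustomOp. As a rough illustration of that pattern only, here is a minimal sketch: the `@PluggableLayer.register(...)` decorator, the `PluggableLayer` base class, and the single `forward_impl` entry point are taken from the diff, while the layer key `"toy_mixer"`, the class `ToyMixer`, and its body are hypothetical and not part of this commit.

```python
# Sketch only: mirrors the registration shape shown in the diff above.
# "toy_mixer" and ToyMixer are hypothetical; the decorator, base class,
# and forward_impl signature follow the pattern this commit applies to
# MambaMixer.
import torch

from vllm.model_executor.custom_op import PluggableLayer


@PluggableLayer.register("toy_mixer")  # hypothetical layer key
class ToyMixer(PluggableLayer):
    def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor) -> None:
        # Single device-agnostic entry point; this commit collapses the
        # old forward_native/forward_cuda pair into forward_impl.
        output.copy_(hidden_states)
```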
@@ -230,10 +230,7 @@ class MambaMixer(MambaBase, CustomOp):
             self.prefix,
         )

-    def forward_native(self, hidden_states: torch.Tensor, output: torch.Tensor):
-        pass
-
-    def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor):
+    def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor):
         """
         Run the Mamba-1 SSM pipeline.

@@ -528,7 +525,7 @@ def mamba_mixer(
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
-    self.forward_cuda(hidden_states=hidden_states, output=output)
+    self.forward_impl(hidden_states=hidden_states, output=output)


 def mamba_mixer_fake(
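The last hunk reroutes the `mamba_mixer` custom op through `forward_impl` instead of `forward_cuda`. Below is a self-contained sketch of that lookup-and-call indirection, written against plain `torch.library` rather than vLLM's own wiring (which this diff does not show): the op body fetches the layer object by name from a side table and calls its `forward_impl`, and a fake kernel lets `torch.compile` trace the op without running the real kernels. The op name `demo::mamba_mixer`, `DemoMixer`, and `_LAYERS` are hypothetical stand-ins; `_LAYERS` plays the role of `forward_context.no_compile_layers`.

```python
# Sketch under stated assumptions, not vLLM's actual implementation.
import torch

_LAYERS: dict[str, "DemoMixer"] = {}


class DemoMixer:
    def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor) -> None:
        # Stand-in for the real Mamba-1 SSM pipeline.
        output.copy_(hidden_states * 2.0)


@torch.library.custom_op("demo::mamba_mixer", mutates_args=("output",))
def mamba_mixer(hidden_states: torch.Tensor, output: torch.Tensor, layer_name: str) -> None:
    # Dispatch to the registered layer's forward_impl, as the diff now
    # does in place of forward_cuda.
    _LAYERS[layer_name].forward_impl(hidden_states=hidden_states, output=output)


@mamba_mixer.register_fake
def _(hidden_states: torch.Tensor, output: torch.Tensor, layer_name: str) -> None:
    # Fake kernel for tracing: the op mutates `output` in place and
    # returns nothing, so there is nothing to compute here.
    return None


if __name__ == "__main__":
    _LAYERS["layers.0.mixer"] = DemoMixer()
    x = torch.randn(4, 8)
    out = torch.empty_like(x)
    torch.ops.demo.mamba_mixer(x, out, "layers.0.mixer")
    assert torch.allclose(out, 2.0 * x)
```

Collapsing `forward_native`/`forward_cuda` into a single `forward_impl` means the op body no longer hard-codes a CUDA entry point, which is what allows a `PluggableLayer` implementation to supply its own backend path.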