[PluggableLayer][3/N] Apply PluggableLayer to mamba layers. (#33660)
Signed-off-by: whx-sjtu <2952154980@qq.com>

@@ -14,7 +14,7 @@ from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import ForwardContext, get_forward_context
-from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.custom_op import PluggableLayer
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.attention import Attention
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -107,8 +107,8 @@ def is_mamba(config: Plamo2Config, i: int) -> bool:
 # vllm.model_executor.layers.mamba.mamba_mixer2.MambaMixer2
 # transformers.models.mamba.modeling_mamba.MambaMixer
 # --8<-- [start:plamo2_mamba_mixer]
-@CustomOp.register("plamo2_mamba_mixer")
-class Plamo2MambaMixer(MambaBase, CustomOp):
+@PluggableLayer.register("plamo2_mamba_mixer")
+class Plamo2MambaMixer(MambaBase, PluggableLayer):
 # --8<-- [end:plamo2_mamba_mixer]

     def __init__(self, vllm_config: VllmConfig, *, prefix: str = "", **kwargs) -> None:
@@ -233,14 +233,6 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
         dt = self.dt_proj(time_step)
         return B, C, dt

-    def forward_native(
-        self,
-        hidden_states: torch.Tensor,
-        output: torch.Tensor,
-        **kwargs,
-    ):
-        pass
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -253,7 +245,7 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
             self.prefix,
         )

-    def forward_cuda(
+    def forward_impl(
         self,
         hidden_states: torch.Tensor,
         output: torch.Tensor,
@@ -494,7 +486,7 @@ def plamo2_mamba_mixer(
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
-    self.forward_cuda(hidden_states=hidden_states, output=output)
+    self.forward_impl(hidden_states=hidden_states, output=output)


 def plamo2_mamba_mixer_fake(
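
Below is a minimal, hedged sketch of the registration-and-dispatch pattern this diff moves Plamo2MambaMixer onto: a layer class registers itself under a string name and exposes a single forward_impl entry point, which the custom op calls after looking the layer up by name. The ToyPluggableLayer registry and ToyMixer class are hypothetical stand-ins, not the vLLM PluggableLayer API; they only illustrate the shape of the change (decorator registration, forward_cuda/forward_native collapsed into forward_impl).

from typing import Callable

import torch


class ToyPluggableLayer:
    """Hypothetical stand-in for a pluggable-layer base: keeps a
    name -> layer-class registry, analogous to registering a layer
    under "plamo2_mamba_mixer" in the diff."""

    _registry: dict[str, type] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[type], type]:
        # Class decorator that records the layer class under `name`.
        def decorator(layer_cls: type) -> type:
            cls._registry[name] = layer_cls
            return layer_cls

        return decorator


@ToyPluggableLayer.register("toy_mamba_mixer")
class ToyMixer(ToyPluggableLayer):
    """Hypothetical layer whose backends all funnel through forward_impl,
    mirroring the forward_cuda -> forward_impl rename in the diff."""

    def forward_impl(self, hidden_states: torch.Tensor,
                     output: torch.Tensor) -> None:
        # Write the result in place through the `output` argument,
        # matching the output-parameter style used by Plamo2MambaMixer.
        output.copy_(hidden_states * 2.0)


# Dispatch by name, loosely analogous to the custom op resolving the layer
# via forward_context.no_compile_layers[layer_name] and then calling
# self.forward_impl(...) in the diff.
layer = ToyPluggableLayer._registry["toy_mamba_mixer"]()
x = torch.ones(4)
out = torch.empty_like(x)
layer.forward_impl(hidden_states=x, output=out)
print(out)  # tensor([2., 2., 2., 2.])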