[Model] Replace Mamba2 RMSNorm Gated with Fused Triton Kernel (#20839)

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Yu Chin Fabian Lim <fabian.lim@gmail.com>
Signed-off-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: Yu Chin Fabian Lim <fabian.lim@gmail.com>
Author: Chih-Chieh Yang
Date: 2025-07-25 09:49:36 -04:00
Committed by: GitHub
Parent: 9fe98d4250
Commit: eab2f3980c
2 changed files with 176 additions and 13 deletions

@@ -24,6 +24,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
     extra_groups_for_head_shards, get_mamba_state_shape)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
+from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_state_update)
 from vllm.model_executor.layers.mamba.ops.ssd_combined import (
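
For context, the newly imported `rms_norm_gated` is the fused Triton kernel that replaces the separate gating and normalization steps of the old path. A minimal, unfused sketch of the `norm_before_gate=False` semantics is below; the function name `rms_norm_gated_ref`, the default `eps`, and the reduction layout are illustrative assumptions, not the kernel's actual implementation or API.

```python
import torch
import torch.nn.functional as F


def rms_norm_gated_ref(x: torch.Tensor,
                       weight: torch.Tensor,
                       z: torch.Tensor,
                       eps: float = 1e-5) -> torch.Tensor:
    # Gate first (norm_before_gate=False): multiply by SiLU(z) in float32
    # for numerical stability, mirroring the native path in the diff below.
    y = x.to(torch.float32) * F.silu(z.to(torch.float32))
    # RMS-normalize over the hidden (last) dimension, then apply the
    # learned per-channel scale and cast back to the input dtype.
    variance = y.pow(2).mean(dim=-1, keepdim=True)
    y = y * torch.rsqrt(variance + eps)
    return (weight * y).to(x.dtype)
```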
@@ -133,21 +134,15 @@ class Mixer2RMSNormGated(CustomOp):
             return x * nn.functional.silu(gate.to(
                 torch.float32)).to(input_dtype)
-        if self.tp_size > 1 or self.n_groups != 1:
+        if (((self.n_groups % self.tp_size) != 0) or self.n_groups != 1):
             return self.forward_native(x, gate)
-        from vllm import _custom_ops as ops
-        # cast x and gate to float32 before silu
-        out = torch.empty_like(x)
-        y = x * nn.functional.silu(gate.to(torch.float32))
-        ops.rms_norm(
-            out,
-            y.to(x.dtype),
-            self.weight.data,
-            self.variance_epsilon,
-        )
-        return out
+        return rms_norm_gated(x,
+                              self.weight.data,
+                              bias=None,
+                              z=gate,
+                              eps=self.variance_epsilon,
+                              norm_before_gate=False)
 def mamba_v2_sharded_weight_loader(
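
A hedged usage sketch of the fused call as it appears in the new `forward_cuda` path; the shapes, dtypes, and `eps` value are made up for illustration, and a CUDA device is assumed since the kernel is implemented in Triton.

```python
import torch

from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated

x = torch.randn(4, 2048, device="cuda", dtype=torch.bfloat16)
gate = torch.randn_like(x)
weight = torch.ones(2048, device="cuda", dtype=torch.bfloat16)

# norm_before_gate=False: apply the SiLU gate first, then RMS-normalize,
# matching the behavior of the replaced _custom_ops.rms_norm path.
out = rms_norm_gated(x, weight, bias=None, z=gate, eps=1e-5,
                     norm_before_gate=False)
```

The fallback condition in the diff routes tensor-parallel group splits that the kernel does not handle back to `forward_native`, so the fused path is only taken when the group layout allows a single contiguous normalization.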