[Model] Replace Mamba2 RMSNorm Gated with Fused Triton Kernel (#20839)
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Yu Chin Fabian Lim <fabian.lim@gmail.com>
Signed-off-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: Yu Chin Fabian Lim <fabian.lim@gmail.com>
@@ -24,6 +24,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
     extra_groups_for_head_shards, get_mamba_state_shape)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
+from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_state_update)
 from vllm.model_executor.layers.mamba.ops.ssd_combined import (
@@ -133,21 +134,15 @@ class Mixer2RMSNormGated(CustomOp):
             return x * nn.functional.silu(gate.to(
                 torch.float32)).to(input_dtype)
 
-        if self.tp_size > 1 or self.n_groups != 1:
+        if (((self.n_groups % self.tp_size) != 0) or self.n_groups != 1):
             return self.forward_native(x, gate)
 
-        from vllm import _custom_ops as ops
-
-        # cast x and gate to float32 before silu
-        out = torch.empty_like(x)
-        y = x * nn.functional.silu(gate.to(torch.float32))
-        ops.rms_norm(
-            out,
-            y.to(x.dtype),
-            self.weight.data,
-            self.variance_epsilon,
-        )
-        return out
+        return rms_norm_gated(x,
+                              self.weight.data,
+                              bias=None,
+                              z=gate,
+                              eps=self.variance_epsilon,
+                              norm_before_gate=False)
 
 
 def mamba_v2_sharded_weight_loader(
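For readers comparing the two paths, the following is a minimal, unfused PyTorch sketch (not part of this commit) of what rms_norm_gated(..., norm_before_gate=False) computes: the same silu-gate-then-RMSNorm sequence as the removed ops.rms_norm path, fused by this change into a single Triton kernel launch. The name rms_norm_gated_reference is hypothetical, and the fused kernel's internal dtype handling may differ from the all-float32 math shown here.

import torch
import torch.nn.functional as F

def rms_norm_gated_reference(x: torch.Tensor, weight: torch.Tensor,
                             z: torch.Tensor, eps: float) -> torch.Tensor:
    # Hypothetical reference, not the vLLM kernel. norm_before_gate=False
    # means the silu gate is applied *before* normalization; float32 is used
    # for numerical stability, as the removed two-step path also did.
    y = x.to(torch.float32) * F.silu(z.to(torch.float32))
    # RMS-normalize over the hidden dimension, then apply the learned scale.
    y = y * torch.rsqrt(y.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (y * weight.to(torch.float32)).to(x.dtype)

if __name__ == "__main__":
    x = torch.randn(2, 4, 64)
    z = torch.randn_like(x)
    w = torch.ones(64)
    print(rms_norm_gated_reference(x, w, z, eps=1e-5).shape)  # torch.Size([2, 4, 64])

Relative to the removed path, the fused kernel folds the elementwise gating and the normalization into one launch, avoiding the torch.empty_like intermediate and an extra round trip through memory.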