[CustomOp] CustomOp FusedRMSNormGated (#35877)

Signed-off-by: Elias Ellison <elias.ellison@gmail.com>
Signed-off-by: eellison <elias.ellison@gmail.com>
eellison
2026-03-06 13:53:37 -05:00
committed by GitHub
parent 26bd43b52d
commit f3c6c9c9d7
2 changed files with 133 additions and 2 deletions
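The diff below registers FusedRMSNormGated under vLLM's CustomOp mechanism: forward_native provides a decomposed PyTorch path that torch.compile/inductor can fuse, while forward_cuda keeps the Triton-backed kernel and serves as the fallback for the residual/prenorm cases. As a reading aid for the forward_native hunk, here is a minimal, self-contained sketch of the gated RMSNorm math it decomposes into, assuming silu gating and no residual/prenorm; ref_rms_norm_gated, the shapes, and the eps value are illustrative and not part of the change.

import torch


def ref_rms_norm_gated(
    x: torch.Tensor,
    g: torch.Tensor,
    weight: torch.Tensor | None,
    eps: float = 1e-5,
) -> torch.Tensor:
    # Normalize in fp32 for numerical stability, mirroring x.float() in the diff.
    x_f = x.float()
    x_n = x_f * torch.rsqrt(x_f.pow(2).mean(dim=-1, keepdim=True) + eps)
    if weight is not None:
        x_n = x_n * weight.float()
    g_f = g.float()
    # silu/swish gate: multiply the normalized activations by g * sigmoid(g).
    return (x_n * g_f * torch.sigmoid(g_f)).to(x.dtype)


# Illustrative usage (shapes and dtype chosen arbitrarily):
x = torch.randn(2, 8, 64, dtype=torch.bfloat16)
g = torch.randn_like(x)
w = torch.ones(64)
out = ref_rms_norm_gated(x, g, w)
print(out.shape, out.dtype)  # torch.Size([2, 8, 64]) torch.bfloat16

The Triton path should produce the same values up to floating-point tolerance; the native decomposition exists so that inductor can fuse these elementwise ops with neighboring ops under torch.compile.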


@@ -12,6 +12,7 @@
 import torch
 import torch.nn as nn
 
+from vllm.model_executor.custom_op import CustomOp
 from vllm.triton_utils import tl, triton
 from vllm.utils.math_utils import cdiv, next_power_of_2
@@ -431,7 +432,8 @@ def rms_norm_gated(
     return y if not prenorm else (y, residual_out.reshape(x_shape_og))
 
 
-class FusedRMSNormGated(nn.Module):
+@CustomOp.register("fused_rms_norm_gated")
+class FusedRMSNormGated(CustomOp):
     def __init__(
         self,
         hidden_size: int,
@@ -458,7 +460,33 @@ class FusedRMSNormGated(nn.Module):
             self.register_parameter("weight", None)
             self.register_parameter("bias", None)
 
-    def forward(
+    def forward_native(
+        self,
+        x: torch.Tensor,
+        g: torch.Tensor,
+        residual: torch.Tensor | None = None,
+        prenorm: bool = False,
+        residual_in_fp32: bool = False,
+    ) -> torch.Tensor:
+        """Decomposed PyTorch ops for torch.compile/inductor fusion."""
+        # TODO(https://github.com/vllm-project/vllm/issues/36175): implement
+        # native residual/prenorm path and unify with RMSNormGated.
+        # For now, fall back to the triton kernel.
+        if residual is not None or prenorm:
+            return self.forward_cuda(x, g, residual, prenorm, residual_in_fp32)
+        x_float = x.float()
+        variance = x_float.pow(2).mean(dim=-1, keepdim=True)
+        x_normed = x_float * torch.rsqrt(variance + self.eps)
+        if self.weight is not None:
+            x_normed = x_normed * self.weight.float()
+        g_float = g.float()
+        if self.activation in ("swish", "silu"):
+            out = x_normed * g_float * torch.sigmoid(g_float)
+        else:  # sigmoid
+            out = x_normed * torch.sigmoid(g_float)
+        return out.to(x.dtype)
+
+    def forward_cuda(
         self,
         x: torch.Tensor,
         g: torch.Tensor,