[CustomOp] CustomOp FusedRMSNormGated (#35877)
Signed-off-by: Elias Ellison <elias.ellison@gmail.com> Signed-off-by: eellison <elias.ellison@gmail.com>
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.math_utils import cdiv, next_power_of_2
|
||||
|
||||
@@ -431,7 +432,8 @@ def rms_norm_gated(
|
||||
return y if not prenorm else (y, residual_out.reshape(x_shape_og))
|
||||
|
||||
|
||||
class FusedRMSNormGated(nn.Module):
|
||||
@CustomOp.register("fused_rms_norm_gated")
|
||||
class FusedRMSNormGated(CustomOp):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
@@ -458,7 +460,33 @@ class FusedRMSNormGated(nn.Module):
|
||||
self.register_parameter("weight", None)
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def forward(
|
||||
def forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
residual: torch.Tensor | None = None,
|
||||
prenorm: bool = False,
|
||||
residual_in_fp32: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Decomposed PyTorch ops for torch.compile/inductor fusion."""
|
||||
# TODO(https://github.com/vllm-project/vllm/issues/36175): implement
|
||||
# native residual/prenorm path and unify with RMSNormGated.
|
||||
# For now, fall back to the triton kernel.
|
||||
if residual is not None or prenorm:
|
||||
return self.forward_cuda(x, g, residual, prenorm, residual_in_fp32)
|
||||
x_float = x.float()
|
||||
variance = x_float.pow(2).mean(dim=-1, keepdim=True)
|
||||
x_normed = x_float * torch.rsqrt(variance + self.eps)
|
||||
if self.weight is not None:
|
||||
x_normed = x_normed * self.weight.float()
|
||||
g_float = g.float()
|
||||
if self.activation in ("swish", "silu"):
|
||||
out = x_normed * g_float * torch.sigmoid(g_float)
|
||||
else: # sigmoid
|
||||
out = x_normed * torch.sigmoid(g_float)
|
||||
return out.to(x.dtype)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user