[PERF] Decouple projections from GDN custom op (#27512)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>

Author: Vadim Gimpelson
Date: 2025-11-04 20:11:41 +04:00
Committed by: GitHub
Parent: 97e3dda84b
Commit: 5fd8f02ea9

3 changed files with 204 additions and 53 deletions


@@ -12,6 +12,7 @@ from vllm.model_executor.layers.batch_invariant import (
    rms_norm_batch_invariant,
    vllm_is_batch_invariant,
)
from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
@@ -369,6 +370,107 @@ class GemmaRMSNorm(CustomOp):
        return self.forward_native(x, residual)


@CustomOp.register("rms_norm_gated")
class RMSNormGated(CustomOp):
    """RMS Normalization with optional gating.

    This is a native PyTorch implementation that supports:
    - Standard RMS normalization
    - Group RMS normalization
    - Optional gating with SiLU activation
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-5,
        group_size: int | None = None,
        norm_before_gate: bool = False,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """Initialize RMSNormGated.

        Args:
            hidden_size: Size of the hidden dimension
            eps: Epsilon for numerical stability
            group_size: If not None, do GroupNorm with each group
                having group_size elements.
                group_size=None is equivalent to group_size=hidden_size
                (i.e. there's only 1 group).
            norm_before_gate: If True and z is provided: out = norm(x) * silu(z)
                If False and z is provided: out = norm(x * silu(z))
            device: Device to create parameters on
            dtype: Data type for parameters
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)

    def forward_native(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Native PyTorch implementation of RMS normalization with gating.

        Args:
            x: Input tensor
            z: Optional gating tensor

        Returns:
            Normalized (and optionally gated) tensor

        If z is not None:
            - norm_before_gate=True: out = norm(x) * silu(z)
            - norm_before_gate=False: out = norm(x * silu(z))
        """
        # Apply gating before normalization if needed
        if z is not None and not self.norm_before_gate:
            x = x * F.silu(z)

        # RMS Normalization
        if self.group_size is None:
            # Standard RMS norm across the last dimension
            variance = x.pow(2).mean(dim=-1, keepdim=True)
            x_normed = x * torch.rsqrt(variance + self.eps)
            out = x_normed * self.weight
        else:
            # Group RMS norm
            from einops import rearrange

            x_group = rearrange(x, "... (g d) -> ... g d", d=self.group_size)
            variance = x_group.pow(2).mean(dim=-1, keepdim=True)
            x_normed = x_group * torch.rsqrt(variance + self.eps)
            out = rearrange(x_normed, "... g d -> ... (g d)") * self.weight

        # Apply gating after normalization if needed
        if z is not None and self.norm_before_gate:
            out = out * F.silu(z)

        return out

    def forward_cuda(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        return rmsnorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            eps=self.eps,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
        )
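
For reference, a minimal pure-PyTorch sketch (illustrative, not part of this diff) of the gating semantics that forward_native implements; the helper name, shapes, and values below are made up for the example:

# Standalone sketch mirroring forward_native's math (assumption: default
# group_size=None, i.e. standard RMS norm over the last dimension).
import torch
import torch.nn.functional as F


def rms_norm_gated_reference(
    x: torch.Tensor,
    weight: torch.Tensor,
    z: torch.Tensor | None = None,
    eps: float = 1e-5,
    norm_before_gate: bool = False,
) -> torch.Tensor:
    # norm_before_gate=False: gate the input first, i.e. norm(x * silu(z))
    if z is not None and not norm_before_gate:
        x = x * F.silu(z)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    out = x * torch.rsqrt(variance + eps) * weight
    # norm_before_gate=True: gate the normalized output, i.e. norm(x) * silu(z)
    if z is not None and norm_before_gate:
        out = out * F.silu(z)
    return out


hidden_size = 8
x = torch.randn(2, hidden_size)
z = torch.randn(2, hidden_size)
w = torch.ones(hidden_size)

# The two gating orders generally produce different results.
out_gate_first = rms_norm_gated_reference(x, w, z, norm_before_gate=False)
out_norm_first = rms_norm_gated_reference(x, w, z, norm_before_gate=True)
print(torch.allclose(out_gate_first, out_norm_first))  # usually False
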
class LayerNorm(nn.Module):
    """
    Layer Normalization.
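
And a small sanity check (again illustrative, not from the diff) of the docstring's claim that group_size=None is equivalent to group_size=hidden_size: group RMS norm with a single group reduces to standard RMS norm.

# Sketch: single-group RMS norm matches full-dimension RMS norm.
import torch

eps = 1e-5
hidden_size = 16
x = torch.randn(4, hidden_size)

# Standard RMS norm over the last dimension (the group_size=None path).
var_full = x.pow(2).mean(dim=-1, keepdim=True)
out_full = x * torch.rsqrt(var_full + eps)

# Group RMS norm with one group of size hidden_size.
x_group = x.reshape(*x.shape[:-1], 1, hidden_size)
var_group = x_group.pow(2).mean(dim=-1, keepdim=True)
out_group = (x_group * torch.rsqrt(var_group + eps)).reshape_as(x)

assert torch.allclose(out_full, out_group)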