[CustomOp] CustomOp FusedRMSNormGated (#35877)
Signed-off-by: Elias Ellison <elias.ellison@gmail.com> Signed-off-by: eellison <elias.ellison@gmail.com>
This commit is contained in:
103
tests/kernels/core/test_fused_rms_norm_gated.py
Normal file
103
tests/kernels/core/test_fused_rms_norm_gated.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests that FusedRMSNormGated decomposes correctly under torch.compile,
matching the eager triton kernel output."""

import pytest
import torch

from vllm.model_executor.layers.fla.ops.kda import FusedRMSNormGated
from vllm.utils.torch_utils import set_random_seed

# Parameter grids expanded by @pytest.mark.parametrize in the tests below.
DTYPES = [torch.bfloat16]
HIDDEN_SIZES = [128, 512]
NUM_TOKENS = [64, 128]
ACTIVATIONS = ["swish", "sigmoid"]
ELEMENTWISE_AFFINE = [True, False]
SEEDS = [0]
|
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("activation", ACTIVATIONS)
@pytest.mark.parametrize("elementwise_affine", ELEMENTWISE_AFFINE)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_compiled_vs_eager(
    default_vllm_config,
    num_tokens: int,
    hidden_size: int,
    activation: str,
    elementwise_affine: bool,
    dtype: torch.dtype,
    seed: int,
) -> None:
    """forward_native decomposition matches forward_cuda triton kernel."""
    # Reset dynamo caches so parametrized cases don't reuse a compiled graph
    # traced for a different configuration.
    torch._dynamo.reset()
    set_random_seed(seed)
    device = torch.device("cuda:0")

    module = FusedRMSNormGated(
        hidden_size,
        elementwise_affine=elementwise_affine,
        eps=1e-5,
        activation=activation,
        device=device,
        dtype=dtype,
    )
    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
    g = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)

    # forward_cuda may modify x in-place, so clone inputs
    cuda_out = module.forward_cuda(x.clone(), g.clone())
    compiled_native = torch.compile(module.forward_native, fullgraph=True)
    native_out = compiled_native(x.clone(), g.clone())

    torch.testing.assert_close(native_out, cuda_out, atol=1e-3, rtol=1e-2)
@pytest.mark.parametrize(
    "shape",
    [
        (1, 16, 32, 128),
        (2, 8, 16, 64),
    ],
)
@pytest.mark.parametrize("activation", ACTIVATIONS)
@pytest.mark.parametrize("elementwise_affine", ELEMENTWISE_AFFINE)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_compiled_vs_eager_multidim(
    default_vllm_config,
    shape: tuple,
    activation: str,
    elementwise_affine: bool,
    dtype: torch.dtype,
    seed: int,
) -> None:
    """forward_native decomposition handles multi-dimensional inputs."""
    # Reset dynamo caches so parametrized cases don't reuse a compiled graph
    # traced for a different configuration.
    torch._dynamo.reset()
    set_random_seed(seed)
    device = torch.device("cuda:0")
    # The norm is applied over the last dimension only.
    head_dim = shape[-1]

    module = FusedRMSNormGated(
        head_dim,
        elementwise_affine=elementwise_affine,
        eps=1e-5,
        activation=activation,
        device=device,
        dtype=dtype,
    )
    x = torch.randn(*shape, dtype=dtype, device=device)
    g = torch.randn(*shape, dtype=dtype, device=device)

    # forward_cuda may modify x in-place, so clone inputs
    cuda_out = module.forward_cuda(x.clone(), g.clone())
    compiled_native = torch.compile(module.forward_native, fullgraph=True)
    native_out = compiled_native(x.clone(), g.clone())

    torch.testing.assert_close(native_out, cuda_out, atol=1e-3, rtol=1e-2)
@@ -12,6 +12,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
from vllm.utils.math_utils import cdiv, next_power_of_2
|
from vllm.utils.math_utils import cdiv, next_power_of_2
|
||||||
|
|
||||||
@@ -431,7 +432,8 @@ def rms_norm_gated(
|
|||||||
return y if not prenorm else (y, residual_out.reshape(x_shape_og))
|
return y if not prenorm else (y, residual_out.reshape(x_shape_og))
|
||||||
|
|
||||||
|
|
||||||
class FusedRMSNormGated(nn.Module):
|
@CustomOp.register("fused_rms_norm_gated")
|
||||||
|
class FusedRMSNormGated(CustomOp):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
@@ -458,7 +460,33 @@ class FusedRMSNormGated(nn.Module):
|
|||||||
self.register_parameter("weight", None)
|
self.register_parameter("weight", None)
|
||||||
self.register_parameter("bias", None)
|
self.register_parameter("bias", None)
|
||||||
|
|
||||||
def forward_native(
    self,
    x: torch.Tensor,
    g: torch.Tensor,
    residual: torch.Tensor | None = None,
    prenorm: bool = False,
    residual_in_fp32: bool = False,
) -> torch.Tensor:
    """Decomposed PyTorch ops for torch.compile/inductor fusion.

    Applies RMS normalization over the last dimension of ``x`` (optionally
    scaled by ``self.weight``), gated by ``g`` through either a SiLU/swish
    or sigmoid activation. Returns a tensor in ``x``'s dtype; statistics
    are computed in fp32 for numerical stability.
    """
    # TODO(https://github.com/vllm-project/vllm/issues/36175): implement
    # native residual/prenorm path and unify with RMSNormGated.
    # For now, fall back to the triton kernel.
    if residual is not None or prenorm:
        return self.forward_cuda(x, g, residual, prenorm, residual_in_fp32)
    # Upcast to fp32 so the variance/rsqrt math doesn't lose precision in
    # low-precision dtypes (e.g. bfloat16).
    x_float = x.float()
    variance = x_float.pow(2).mean(dim=-1, keepdim=True)
    x_normed = x_float * torch.rsqrt(variance + self.eps)
    # weight is None when elementwise_affine=False (see __init__'s
    # register_parameter("weight", None)).
    if self.weight is not None:
        x_normed = x_normed * self.weight.float()
    g_float = g.float()
    if self.activation in ("swish", "silu"):
        # silu(g) = g * sigmoid(g)
        out = x_normed * g_float * torch.sigmoid(g_float)
    else:  # sigmoid
        out = x_normed * torch.sigmoid(g_float)
    return out.to(x.dtype)
|
def forward_cuda(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
g: torch.Tensor,
|
g: torch.Tensor,
|
||||||
|
|||||||
Reference in New Issue
Block a user