diff --git a/tests/kernels/core/test_fused_rms_norm_gated.py b/tests/kernels/core/test_fused_rms_norm_gated.py new file mode 100644 index 000000000..793dd02a9 --- /dev/null +++ b/tests/kernels/core/test_fused_rms_norm_gated.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Tests that FusedRMSNormGated decomposes correctly under torch.compile, +matching the eager triton kernel output.""" + +import pytest +import torch + +from vllm.model_executor.layers.fla.ops.kda import FusedRMSNormGated +from vllm.utils.torch_utils import set_random_seed + +DTYPES = [torch.bfloat16] +HIDDEN_SIZES = [128, 512] +NUM_TOKENS = [64, 128] +ACTIVATIONS = ["swish", "sigmoid"] +ELEMENTWISE_AFFINE = [True, False] +SEEDS = [0] + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("activation", ACTIVATIONS) +@pytest.mark.parametrize("elementwise_affine", ELEMENTWISE_AFFINE) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_compiled_vs_eager( + default_vllm_config, + num_tokens: int, + hidden_size: int, + activation: str, + elementwise_affine: bool, + dtype: torch.dtype, + seed: int, +) -> None: + """forward_native decomposition matches forward_cuda triton kernel.""" + torch._dynamo.reset() + set_random_seed(seed) + device = torch.device("cuda:0") + + module = FusedRMSNormGated( + hidden_size, + elementwise_affine=elementwise_affine, + eps=1e-5, + activation=activation, + device=device, + dtype=dtype, + ) + x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device) + g = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device) + + # forward_cuda may modify x in-place, so clone inputs + cuda_out = module.forward_cuda(x.clone(), g.clone()) + compiled_native = torch.compile(module.forward_native, fullgraph=True) + native_out = compiled_native(x.clone(), g.clone()) + + torch.testing.assert_close(native_out, cuda_out, atol=1e-3, rtol=1e-2) + + +@pytest.mark.parametrize( + "shape", + [ + (1, 16, 32, 128), + (2, 8, 16, 64), + ], +) +@pytest.mark.parametrize("activation", ACTIVATIONS) +@pytest.mark.parametrize("elementwise_affine", ELEMENTWISE_AFFINE) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_compiled_vs_eager_multidim( + default_vllm_config, + shape: tuple, + activation: str, + elementwise_affine: bool, + dtype: torch.dtype, + seed: int, +) -> None: + """forward_native decomposition handles multi-dimensional inputs.""" + torch._dynamo.reset() + set_random_seed(seed) + device = torch.device("cuda:0") + head_dim = shape[-1] + + module = FusedRMSNormGated( + head_dim, + elementwise_affine=elementwise_affine, + eps=1e-5, + activation=activation, + device=device, + dtype=dtype, + ) + x = torch.randn(*shape, dtype=dtype, device=device) + g = torch.randn(*shape, dtype=dtype, device=device) + + # forward_cuda may modify x in-place, so clone inputs + cuda_out = module.forward_cuda(x.clone(), g.clone()) + compiled_native = torch.compile(module.forward_native, fullgraph=True) + native_out = compiled_native(x.clone(), g.clone()) + + torch.testing.assert_close(native_out, cuda_out, atol=1e-3, rtol=1e-2) diff --git a/vllm/model_executor/layers/fla/ops/kda.py b/vllm/model_executor/layers/fla/ops/kda.py index 7145933e7..460be44c8 100644 --- a/vllm/model_executor/layers/fla/ops/kda.py +++ b/vllm/model_executor/layers/fla/ops/kda.py @@ -12,6 +12,7 @@ import torch import torch.nn as nn +from vllm.model_executor.custom_op import CustomOp from vllm.triton_utils import tl, triton from vllm.utils.math_utils import cdiv, next_power_of_2 @@ -431,7 +432,8 @@ def rms_norm_gated( return y if not prenorm else (y, residual_out.reshape(x_shape_og)) -class FusedRMSNormGated(nn.Module): +@CustomOp.register("fused_rms_norm_gated") +class FusedRMSNormGated(CustomOp): def __init__( self, hidden_size: int, @@ -458,7 +460,33 @@ class FusedRMSNormGated(nn.Module): self.register_parameter("weight", None) self.register_parameter("bias", None) - def forward( + def forward_native( + self, + x: torch.Tensor, + g: torch.Tensor, + residual: torch.Tensor | None = None, + prenorm: bool = False, + residual_in_fp32: bool = False, + ) -> torch.Tensor: + """Decomposed PyTorch ops for torch.compile/inductor fusion.""" + # TODO(https://github.com/vllm-project/vllm/issues/36175): implement + # native residual/prenorm path and unify with RMSNormGated. + # For now, fall back to the triton kernel. + if residual is not None or prenorm: + return self.forward_cuda(x, g, residual, prenorm, residual_in_fp32) + x_float = x.float() + variance = x_float.pow(2).mean(dim=-1, keepdim=True) + x_normed = x_float * torch.rsqrt(variance + self.eps) + if self.weight is not None: + x_normed = x_normed * self.weight.float() + g_float = g.float() + if self.activation in ("swish", "silu"): + out = x_normed * g_float * torch.sigmoid(g_float) + else: # sigmoid + out = x_normed * torch.sigmoid(g_float) + return out.to(x.dtype) + + def forward_cuda( self, x: torch.Tensor, g: torch.Tensor,