From 6290470843c131681e3e1318ae71070a34f33225 Mon Sep 17 00:00:00 2001
From: haosdent
Date: Mon, 2 Mar 2026 04:14:46 +0800
Subject: [PATCH] [Bugfix] Fix dtype mismatch in RMSNormGated.forward_native() during torch.compile (#35256)

Signed-off-by: haosdent
---
 tests/kernels/test_fla_layernorm_guard.py | 64 ++++++++++++++++++++++-
 vllm/model_executor/layers/layernorm.py   | 11 ++--
 2 files changed, 71 insertions(+), 4 deletions(-)

diff --git a/tests/kernels/test_fla_layernorm_guard.py b/tests/kernels/test_fla_layernorm_guard.py
index 2ece5497c..4858ff2d7 100644
--- a/tests/kernels/test_fla_layernorm_guard.py
+++ b/tests/kernels/test_fla_layernorm_guard.py
@@ -74,7 +74,7 @@ def layer_norm_ref(
     return out.to(dtype)
 
 
-DTYPES = [torch.bfloat16, torch.float32]
+DTYPES = [torch.float16, torch.bfloat16, torch.float32]
 # Test various M sizes to ensure rows_per_block logic works correctly
 NUM_TOKENS = [
     1,
@@ -380,6 +380,68 @@ def test_multidimensional_input(
     torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)
 
 
+@pytest.mark.parametrize("num_tokens", [1, 128, 1024])
+@pytest.mark.parametrize("hidden_size", [64, 256, 1024])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
+@pytest.mark.parametrize("has_gate", [True, False])
+@pytest.mark.parametrize("group_size", [None, 64])
+@pytest.mark.parametrize("norm_before_gate", [True, False])
+@torch.inference_mode()
+def test_rmsnorm_gated_forward_native_dtype(
+    default_vllm_config,
+    num_tokens: int,
+    hidden_size: int,
+    dtype: torch.dtype,
+    has_gate: bool,
+    group_size: int | None,
+    norm_before_gate: bool,
+):
+    """Test that RMSNormGated.forward_native preserves input dtype."""
+    if group_size is not None and hidden_size % group_size != 0:
+        pytest.skip(
+            f"hidden_size {hidden_size} not divisible by group_size {group_size}"
+        )
+
+    from vllm.model_executor.layers.layernorm import RMSNormGated
+
+    device = torch.device("cuda:0")
+    set_random_seed(42)
+
+    layer = RMSNormGated(
+        hidden_size,
+        eps=1e-5,
+        group_size=group_size,
+        norm_before_gate=norm_before_gate,
+        device=device,
+        dtype=dtype,
+    )
+
+    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
+    z = (
+        torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
+        if has_gate
+        else None
+    )
+
+    out = layer.forward_native(x, z)
+
+    # Verify dtype preservation
+    assert out.dtype == dtype, f"Expected {dtype}, got {out.dtype}"
+
+    # Verify numerical correctness against reference
+    ref_out = rms_norm_ref(
+        x,
+        layer.weight,
+        layer.bias,
+        z=z,
+        eps=1e-5,
+        group_size=group_size,
+        norm_before_gate=norm_before_gate,
+        upcast=True,
+    )
+    torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)
+
+
 if __name__ == "__main__":
     # Run a quick smoke test
     test_layer_norm_fwd_basic(128, 1024, torch.float16, 42, False)
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index 72f42de06..2a1180dd6 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -557,6 +557,11 @@ class RMSNormGated(CustomOp):
         - norm_before_gate=True: out = norm(x) * silu(z)
         - norm_before_gate=False: out = norm(x * silu(z))
         """
+        orig_dtype = x.dtype
+        x = x.float()
+        weight = self.weight.float()
+        z = z.float() if z is not None else None
+
         # Apply gating before normalization if needed
         if z is not None and not self.norm_before_gate:
             x = x * F.silu(z)
@@ -566,7 +571,7 @@
             # Standard RMS norm across the last dimension
             variance = x.pow(2).mean(dim=-1, keepdim=True)
             x_normed = x * torch.rsqrt(variance + self.eps)
-            out = x_normed * self.weight
+            out = x_normed * weight
         else:
             # Group RMS norm
             from einops import rearrange
@@ -574,13 +579,13 @@
             x_group = rearrange(x, "... (g d) -> ... g d", d=self.group_size)
             variance = x_group.pow(2).mean(dim=-1, keepdim=True)
             x_normed = x_group * torch.rsqrt(variance + self.eps)
-            out = rearrange(x_normed, "... g d -> ... (g d)") * self.weight
+            out = rearrange(x_normed, "... g d -> ... (g d)") * weight
 
         # Apply gating after normalization if needed
         if z is not None and self.norm_before_gate:
             out = out * F.silu(z)
 
-        return out.to(x.dtype)
+        return out.to(orig_dtype)
 
     def forward_cuda(
         self, x: torch.Tensor, z: torch.Tensor | None = None
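
A minimal usage sketch (illustrative, not part of the patch) of the dtype contract the new test pins down. It mirrors the RMSNormGated constructor arguments and the forward_native() call from the diff above; the device string, shapes, and dtype are assumptions for the example, and outside pytest you may need the same default vLLM config that the test's default_vllm_config fixture provides.

    import torch
    from vllm.model_executor.layers.layernorm import RMSNormGated

    layer = RMSNormGated(
        256,
        eps=1e-5,
        group_size=None,
        norm_before_gate=True,
        device="cuda",
        dtype=torch.bfloat16,
    )
    x = torch.randn(8, 256, dtype=torch.bfloat16, device="cuda")
    z = torch.randn_like(x)

    out = layer.forward_native(x, z)
    # With this fix the intermediate math runs in float32 and the result is
    # cast back to the caller's dtype, so the assertion holds for fp16,
    # bf16, and fp32 inputs alike.
    assert out.dtype == x.dtype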