[Bugfix] Fix dtype mismatch in RMSNormGated.forward_native() during torch.compile (#35256)
Signed-off-by: haosdent <haosdent@gmail.com>
This commit is contained in:
@@ -74,7 +74,7 @@ def layer_norm_ref(
|
||||
return out.to(dtype)
|
||||
|
||||
|
||||
DTYPES = [torch.bfloat16, torch.float32]
|
||||
DTYPES = [torch.float16, torch.bfloat16, torch.float32]
|
||||
# Test various M sizes to ensure rows_per_block logic works correctly
|
||||
NUM_TOKENS = [
|
||||
1,
|
||||
@@ -380,6 +380,68 @@ def test_multidimensional_input(
|
||||
torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [1, 128, 1024])
@pytest.mark.parametrize("hidden_size", [64, 256, 1024])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("has_gate", [True, False])
@pytest.mark.parametrize("group_size", [None, 64])
@pytest.mark.parametrize("norm_before_gate", [True, False])
@torch.inference_mode()
def test_rmsnorm_gated_forward_native_dtype(
    default_vllm_config,
    num_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
    has_gate: bool,
    group_size: int | None,
    norm_before_gate: bool,
):
    """Test that RMSNormGated.forward_native preserves input dtype.

    Regression test: forward_native internally upcasts to float32, so the
    output must be cast back to the input dtype before returning.  Also
    checks numerical agreement against the upcast reference implementation.
    """
    # group_size must evenly divide hidden_size for group RMS norm.
    if group_size is not None and hidden_size % group_size != 0:
        pytest.skip(
            f"hidden_size {hidden_size} not divisible by group_size {group_size}"
        )

    # Local import so test collection does not require the layer module.
    from vllm.model_executor.layers.layernorm import RMSNormGated

    device = torch.device("cuda:0")
    set_random_seed(42)

    layer = RMSNormGated(
        hidden_size,
        eps=1e-5,
        group_size=group_size,
        norm_before_gate=norm_before_gate,
        device=device,
        dtype=dtype,
    )

    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
    z = (
        torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
        if has_gate
        else None
    )

    out = layer.forward_native(x, z)

    # Verify dtype preservation -- the core of the bugfix.
    assert out.dtype == dtype, f"Expected {dtype}, got {out.dtype}"

    # Verify numerical correctness against reference (upcast=True so the
    # reference also computes in float32 internally).
    ref_out = rms_norm_ref(
        x,
        layer.weight,
        layer.bias,
        z=z,
        eps=1e-5,
        group_size=group_size,
        norm_before_gate=norm_before_gate,
        upcast=True,
    )
    torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run a quick smoke test when executed directly (pytest runs the full
    # parametrized suite).
    test_layer_norm_fwd_basic(128, 1024, torch.float16, 42, False)
||||
@@ -557,6 +557,11 @@ class RMSNormGated(CustomOp):
|
||||
- norm_before_gate=True: out = norm(x) * silu(z)
|
||||
- norm_before_gate=False: out = norm(x * silu(z))
|
||||
"""
|
||||
orig_dtype = x.dtype
|
||||
x = x.float()
|
||||
weight = self.weight.float()
|
||||
z = z.float() if z is not None else None
|
||||
|
||||
# Apply gating before normalization if needed
|
||||
if z is not None and not self.norm_before_gate:
|
||||
x = x * F.silu(z)
|
||||
@@ -566,7 +571,7 @@ class RMSNormGated(CustomOp):
|
||||
# Standard RMS norm across the last dimension
|
||||
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
x_normed = x * torch.rsqrt(variance + self.eps)
|
||||
out = x_normed * self.weight
|
||||
out = x_normed * weight
|
||||
else:
|
||||
# Group RMS norm
|
||||
from einops import rearrange
|
||||
@@ -574,13 +579,13 @@ class RMSNormGated(CustomOp):
|
||||
x_group = rearrange(x, "... (g d) -> ... g d", d=self.group_size)
|
||||
variance = x_group.pow(2).mean(dim=-1, keepdim=True)
|
||||
x_normed = x_group * torch.rsqrt(variance + self.eps)
|
||||
out = rearrange(x_normed, "... g d -> ... (g d)") * self.weight
|
||||
out = rearrange(x_normed, "... g d -> ... (g d)") * weight
|
||||
|
||||
# Apply gating after normalization if needed
|
||||
if z is not None and self.norm_before_gate:
|
||||
out = out * F.silu(z)
|
||||
|
||||
return out.to(x.dtype)
|
||||
return out.to(orig_dtype)
|
||||
|
||||
def forward_cuda(
|
||||
self, x: torch.Tensor, z: torch.Tensor | None = None
|
||||
|
||||
Reference in New Issue
Block a user