[vLLM IR] rework gemma_rms_norm (#39014)

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Jiangyun Zhu
2026-04-07 16:37:00 +08:00
committed by GitHub
parent da4c0e4db9
commit 8060bb0333
8 changed files with 106 additions and 75 deletions

View File

@@ -6,7 +6,7 @@ import torch
from tests.kernels.quant_utils import FP8_DTYPE
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
@@ -162,3 +162,31 @@ def test_fused_rms_norm_quant(
atol=1e-3,
rtol=1e-3,
)
@torch.inference_mode()
def test_gemma_rms_norm_mixed_input_weight_dtype(default_vllm_config) -> None:
if not torch.cuda.is_available():
pytest.skip("CUDA required")
device = CUDA_DEVICES[0]
torch.set_default_device(device)
num_tokens, hidden_size = 32, 1024
x = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device=device)
layer = GemmaRMSNorm(hidden_size, eps=1e-6).to(device=device)
layer.weight.data.normal_(mean=0.0, std=0.1)
# Gemma uses fp32 weight parameter while activations can be bf16.
assert layer.weight.dtype == torch.float32
out = layer(x)
x_fp32 = x.float()
weight_fp32 = layer.weight.data.float() + 1.0
variance = x_fp32.pow(2).mean(dim=-1, keepdim=True)
ref = (x_fp32 * torch.rsqrt(variance + layer.variance_epsilon) * weight_fp32).to(
x.dtype
)
assert out.dtype == x.dtype
torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)