[vLLM IR] rework gemma_rms_norm (#39014)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -6,7 +6,7 @@ import torch
|
||||
|
||||
from tests.kernels.quant_utils import FP8_DTYPE
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
@@ -162,3 +162,31 @@ def test_fused_rms_norm_quant(
|
||||
atol=1e-3,
|
||||
rtol=1e-3,
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
def test_gemma_rms_norm_mixed_input_weight_dtype(default_vllm_config) -> None:
    """Check GemmaRMSNorm on bf16 activations with an fp32 weight.

    The layer output is compared against a reference computed entirely in
    fp32 (normalise by the root mean square over the last dimension, scale
    by ``weight + 1``) and then cast back to the activation dtype. The test
    is skipped when CUDA is not available.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA required")

    device = CUDA_DEVICES[0]
    torch.set_default_device(device)

    num_tokens = 32
    hidden_size = 1024
    activations = torch.randn(
        num_tokens, hidden_size, dtype=torch.bfloat16, device=device
    )
    norm = GemmaRMSNorm(hidden_size, eps=1e-6).to(device=device)
    # Randomise the weight so the scaling term actually affects the result.
    norm.weight.data.normal_(mean=0.0, std=0.1)

    # Gemma uses fp32 weight parameter while activations can be bf16.
    assert norm.weight.dtype == torch.float32
    actual = norm(activations)

    # Reference path: every intermediate stays in fp32 until the final cast.
    activations_fp32 = activations.float()
    scale_fp32 = norm.weight.data.float() + 1.0
    mean_square = activations_fp32.pow(2).mean(dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + norm.variance_epsilon)
    expected = (activations_fp32 * inv_rms * scale_fp32).to(activations.dtype)

    # Output must keep the activation dtype, not be promoted to the weight's.
    assert actual.dtype == activations.dtype
    torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2)
|
||||
|
||||
Reference in New Issue
Block a user