[CI][BugFix][AMD][FP8] Fix test_rms_norm so it runs correctly on ROCm (#32372)

Signed-off-by: Randall Smith <ransmith@amd.com>
Co-authored-by: Randall Smith <ransmith@amd.com>
This commit is contained in:
rasmith
2026-01-15 05:05:54 -06:00
committed by GitHub
parent c5891b5430
commit 8853a50af2

View File

@@ -14,9 +14,10 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8,
)
from vllm.platforms import current_platform
DTYPES = [torch.bfloat16, torch.float]
QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn]
QUANT_DTYPES = [torch.int8, current_platform.fp8_dtype()]
VEC_HIDDEN_SIZES = [1024, 1025, 1027, 1029]
# Avoid combinatorial explosion with full Cartesian product
NUM_TOKENS_HIDDEN_SIZES = [
@@ -61,14 +62,14 @@ def ref_dynamic_per_token_or_block_quant(
group_size: list[int] | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
if scale_ub is not None:
assert quant_dtype == torch.float8_e4m3fn
assert quant_dtype == current_platform.fp8_dtype()
# Norm
torch_out, residual = ref_rms_norm(rms_norm_layer, x, residual)
# Quant
if group_size is not None:
if quant_dtype == torch.float8_e4m3fn:
if quant_dtype == current_platform.fp8_dtype():
torch_out, scales = per_token_group_quant_fp8(
torch_out, group_size=group_size[1], use_ue8m0=False
)
@@ -78,7 +79,7 @@ def ref_dynamic_per_token_or_block_quant(
torch_out, group_size=group_size[1]
)
else:
if quant_dtype == torch.float8_e4m3fn:
if quant_dtype == current_platform.fp8_dtype():
torch_out, scales = ops.scaled_fp8_quant(
torch_out, scale_ub=scale_ub, use_per_token_if_dynamic=True
)
@@ -162,6 +163,7 @@ def test_rms_norm(
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device(device)
torch.cuda.set_device(device)
if group_size is not None and hidden_size % group_size[1] != 0:
# skip
@@ -171,7 +173,7 @@ def test_rms_norm(
# blockwise baseline doesn't support scale_ub
return
if has_scale_ub and quant_dtype != torch.float8_e4m3fn:
if has_scale_ub and quant_dtype != current_platform.fp8_dtype():
# skip
return