[CI][BugFix][AMD][FP8] Fix test_rms_norm so it runs correctly on ROCm (#32372)
Signed-off-by: Randall Smith <ransmith@amd.com> Co-authored-by: Randall Smith <ransmith@amd.com>
This commit is contained in:
@@ -14,9 +14,10 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
per_token_group_quant_int8,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
DTYPES = [torch.bfloat16, torch.float]
|
||||
QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn]
|
||||
QUANT_DTYPES = [torch.int8, current_platform.fp8_dtype()]
|
||||
VEC_HIDDEN_SIZES = [1024, 1025, 1027, 1029]
|
||||
# Avoid combinatorial explosion with full Cartesian product
|
||||
NUM_TOKENS_HIDDEN_SIZES = [
|
||||
@@ -61,14 +62,14 @@ def ref_dynamic_per_token_or_block_quant(
|
||||
group_size: list[int] | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
|
||||
if scale_ub is not None:
|
||||
assert quant_dtype == torch.float8_e4m3fn
|
||||
assert quant_dtype == current_platform.fp8_dtype()
|
||||
|
||||
# Norm
|
||||
torch_out, residual = ref_rms_norm(rms_norm_layer, x, residual)
|
||||
|
||||
# Quant
|
||||
if group_size is not None:
|
||||
if quant_dtype == torch.float8_e4m3fn:
|
||||
if quant_dtype == current_platform.fp8_dtype():
|
||||
torch_out, scales = per_token_group_quant_fp8(
|
||||
torch_out, group_size=group_size[1], use_ue8m0=False
|
||||
)
|
||||
@@ -78,7 +79,7 @@ def ref_dynamic_per_token_or_block_quant(
|
||||
torch_out, group_size=group_size[1]
|
||||
)
|
||||
else:
|
||||
if quant_dtype == torch.float8_e4m3fn:
|
||||
if quant_dtype == current_platform.fp8_dtype():
|
||||
torch_out, scales = ops.scaled_fp8_quant(
|
||||
torch_out, scale_ub=scale_ub, use_per_token_if_dynamic=True
|
||||
)
|
||||
@@ -162,6 +163,7 @@ def test_rms_norm(
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
if group_size is not None and hidden_size % group_size[1] != 0:
|
||||
# skip
|
||||
@@ -171,7 +173,7 @@ def test_rms_norm(
|
||||
# blockwise baseline doesn't support scale_ub
|
||||
return
|
||||
|
||||
if has_scale_ub and quant_dtype != torch.float8_e4m3fn:
|
||||
if has_scale_ub and quant_dtype != current_platform.fp8_dtype():
|
||||
# skip
|
||||
return
|
||||
|
||||
|
||||
Reference in New Issue
Block a user