[Hardware][AMD][CI][Bugfix] Fix AMD Quantization test group (#31713)

Signed-off-by: Matthew Wong <Matthew.Wong2@amd.com>
This commit is contained in:
Matt
2026-01-11 01:19:46 -06:00
committed by GitHub
parent 9103ed1696
commit bde57ab2ed
12 changed files with 114 additions and 52 deletions

View File

@@ -6,18 +6,12 @@ Run `pytest tests/quantization/test_ptpc_fp8.py --forked`.
"""
import pytest
import torch
from tests.quantization.utils import is_quant_method_supported
from vllm.model_executor.layers.quantization.fp8 import Fp8KVCacheMethod
from vllm.model_executor.layers.quantization.ptpc_fp8 import PTPCFp8LinearMethod
from vllm.platforms import current_platform
-UNSUPPORTED_STR = (
-    "Currently torch._scaled_mm (hipBLASLt) rowwise gemm only "
-    "support output dtype of bfloat16. torch.float16 is specified."
-)
@pytest.fixture(scope="function", autouse=True)
def enable_pickle(monkeypatch):
@@ -30,24 +24,17 @@ def enable_pickle(monkeypatch):
reason="PTPC FP8 is not supported on this GPU type.",
)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="This test is for ROCm GPU.")
-@pytest.mark.parametrize("dtype", ["auto", "bfloat16", "float16"])
-@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"])
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:
-    try:
-        llm = vllm_runner(
-            "facebook/opt-125m",
-            dtype=dtype,
-            quantization="ptpc_fp8",
-            enforce_eager=True,
-            kv_cache_dtype=kv_cache_dtype,
-        )
-    except AssertionError as e:
-        if str(e) == UNSUPPORTED_STR:
-            # If the error message matches, the test passes
-            return
-        else:
-            # If the error message does not match, re-raise the exception
-            raise
+    llm = vllm_runner(
+        "facebook/opt-125m",
+        dtype=dtype,
+        quantization="ptpc_fp8",
+        enforce_eager=True,
+        kv_cache_dtype=kv_cache_dtype,
+        allow_deprecated_quantization=True,
+    )
with llm:
@@ -60,9 +47,9 @@ def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:
assert attn._k_scale == 1.0
assert attn._v_scale == 1.0
+        # For GPUs with hardware support, we keep weights in fp8
-        if current_platform.has_device_capability(94):
-            # For GPUs with hardware support, we keep weights in fp8
-            assert fc1.weight.dtype == torch.float8_e4m3fnuz
+        assert fc1.weight.dtype == current_platform.fp8_dtype()
llm.apply_model(check_model)