[Hardware][AMD][CI][Bugfix] Fix AMD Quantization test group (#31713)

Signed-off-by: Matthew Wong <Matthew.Wong2@amd.com>
This commit is contained in:
Matt
2026-01-11 01:19:46 -06:00
committed by GitHub
parent 9103ed1696
commit bde57ab2ed
12 changed files with 114 additions and 52 deletions

View File

@@ -36,7 +36,9 @@ MODELS = [
     reason="FP8 is not supported on this GPU type.",
 )
 @pytest.mark.parametrize("model_id", MODELS)
-@pytest.mark.parametrize("force_marlin", [False, True])
+@pytest.mark.parametrize(
+    "force_marlin", [False] if current_platform.is_rocm() else [False, True]
+)
 @pytest.mark.parametrize(
     "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
 )
@@ -125,7 +127,9 @@ def test_kv_cache_model_load_and_run(
     reason="FP8 is not supported on this GPU type.",
 )
 @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
-@pytest.mark.parametrize("force_marlin", [False, True])
+@pytest.mark.parametrize(
+    "force_marlin", [False] if current_platform.is_rocm() else [False, True]
+)
 @pytest.mark.parametrize(
     "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
 )
@@ -197,10 +201,10 @@ def test_scaled_fp8_quant(dtype) -> None:
     def quantize_ref(tensor, inv_scale):
         # The reference implementation that fully aligns to
         # the kernel being tested.
-        finfo = torch.finfo(torch.float8_e4m3fn)
+        finfo = torch.finfo(current_platform.fp8_dtype())
         scale = inv_scale.reciprocal()
         qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
-        qweight = qweight.to(torch.float8_e4m3fn)
+        qweight = qweight.to(current_platform.fp8_dtype())
         return qweight

     def per_tensor_dequantize(tensor, inv_scale, dtype):
@@ -267,6 +271,10 @@ def test_scaled_fp8_quant(dtype) -> None:
 )
+@pytest.mark.skipif(
+    current_platform.is_fp8_fnuz(),
+    reason="FP8 e4m3fn weight reloading is not supported on e4m3fnuz platforms",
+)
 @pytest.mark.parametrize("method_cls", [Fp8LinearMethod, Fp8MoEMethod])
 # FP8 weight reloading does not support online quantization
 @pytest.mark.parametrize("is_checkpoint_fp8_serialized", [True])  # skip False