[Hardware][AMD][CI][Bugfix] Fix AMD Quantization test group (#31713)
Signed-off-by: Matthew Wong <Matthew.Wong2@amd.com>
This commit is contained in:
@@ -36,7 +36,9 @@ MODELS = [
|
||||
reason="FP8 is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("model_id", MODELS)
|
||||
@pytest.mark.parametrize("force_marlin", [False, True])
|
||||
@pytest.mark.parametrize(
|
||||
"force_marlin", [False] if current_platform.is_rocm() else [False, True]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
|
||||
)
|
||||
@@ -125,7 +127,9 @@ def test_kv_cache_model_load_and_run(
|
||||
reason="FP8 is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
|
||||
@pytest.mark.parametrize("force_marlin", [False, True])
|
||||
@pytest.mark.parametrize(
|
||||
"force_marlin", [False] if current_platform.is_rocm() else [False, True]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
|
||||
)
|
||||
@@ -197,10 +201,10 @@ def test_scaled_fp8_quant(dtype) -> None:
|
||||
def quantize_ref(tensor, inv_scale):
|
||||
# The reference implementation that fully aligns to
|
||||
# the kernel being tested.
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
finfo = torch.finfo(current_platform.fp8_dtype())
|
||||
scale = inv_scale.reciprocal()
|
||||
qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
|
||||
qweight = qweight.to(torch.float8_e4m3fn)
|
||||
qweight = qweight.to(current_platform.fp8_dtype())
|
||||
return qweight
|
||||
|
||||
def per_tensor_dequantize(tensor, inv_scale, dtype):
|
||||
@@ -267,6 +271,10 @@ def test_scaled_fp8_quant(dtype) -> None:
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
current_platform.is_fp8_fnuz(),
|
||||
reason="FP8 e4m3fn weight reloading is not supported on e4m3fnuz platforms",
|
||||
)
|
||||
@pytest.mark.parametrize("method_cls", [Fp8LinearMethod, Fp8MoEMethod])
|
||||
# FP8 weight reloading does not support online quantization
|
||||
@pytest.mark.parametrize("is_checkpoint_fp8_serialized", [True]) # skip False
|
||||
|
||||
Reference in New Issue
Block a user