diff --git a/requirements/rocm.txt b/requirements/rocm.txt index 9f2b39199..fcc67e463 100644 --- a/requirements/rocm.txt +++ b/requirements/rocm.txt @@ -19,4 +19,7 @@ setuptools>=77.0.3,<80.0.0 setuptools-scm>=8 runai-model-streamer[s3,gcs]==0.15.3 conch-triton-kernels==1.2.1 -timm>=1.0.17 \ No newline at end of file +timm>=1.0.17 +# amd-quark: required for Quark quantization on ROCm +# To be consistent with test_quark.py +amd-quark>=0.8.99 \ No newline at end of file diff --git a/tests/quantization/test_quark.py b/tests/quantization/test_quark.py index 0ff6e8407..a560494a4 100644 --- a/tests/quantization/test_quark.py +++ b/tests/quantization/test_quark.py @@ -26,9 +26,12 @@ from vllm.platforms import current_platform from .reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch +# Minimum amd-quark version for MXFP4/OCP_MX tests (single source of truth). +QUARK_MXFP4_MIN_VERSION = "0.8.99" + QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse( importlib.metadata.version("amd-quark") -) >= version.parse("0.8.99") +) >= version.parse(QUARK_MXFP4_MIN_VERSION) if QUARK_MXFP4_AVAILABLE: from quark.torch.export.nn.modules.realquantizer import StaticScaledRealQuantizer @@ -200,7 +203,10 @@ WIKITEXT_ACCURACY_CONFIGS = [ ] -@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available") +@pytest.mark.skipif( + not QUARK_MXFP4_AVAILABLE, + reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available", +) @pytest.mark.parametrize("config", WIKITEXT_ACCURACY_CONFIGS) @pytest.mark.parametrize("tp_size", [1, 2]) def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int): @@ -231,7 +237,10 @@ def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int): @pytest.mark.parametrize("config", GSM8K_ACCURACY_CONFIGS) -@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available") +@pytest.mark.skipif( + not QUARK_MXFP4_AVAILABLE, + reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available", +) @pytest.mark.skipif( not HF_HUB_AMD_ORG_ACCESS, reason="Read access to huggingface.co/amd is required for this test.", @@ -261,7 +270,10 @@ def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig): ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" -@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available") +@pytest.mark.skipif( + not QUARK_MXFP4_AVAILABLE, + reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available", +) @pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]]) def test_mxfp4_fused_qdq_match_quark(float_dtype: torch.dtype, scalings: list[int]): @@ -289,7 +301,10 @@ def test_mxfp4_fused_qdq_match_quark(float_dtype: torch.dtype, scalings: list[in ) -@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available") +@pytest.mark.skipif( + not QUARK_MXFP4_AVAILABLE, + reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available", +) @pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]]) def test_mxfp4_dequant_kernel_match_quark(