[ROCm] add amd-quark package in requirements for rocm to use quantized models (#35658)

Signed-off-by: Hongxia Yang <hongxiay.yang@amd.com>
Co-authored-by: Hongxia Yang <hongxiay.yang@amd.com>
This commit is contained in:
Hongxia Yang
2026-03-02 01:02:43 -05:00
committed by GitHub
parent 92f5d0f070
commit f26650d649
2 changed files with 24 additions and 6 deletions

View File

@@ -19,4 +19,7 @@ setuptools>=77.0.3,<80.0.0
setuptools-scm>=8
runai-model-streamer[s3,gcs]==0.15.3
conch-triton-kernels==1.2.1
timm>=1.0.17
timm>=1.0.17
# amd-quark: required for Quark quantization on ROCm
# To be consistent with test_quark.py
amd-quark>=0.8.99

View File

@@ -26,9 +26,12 @@ from vllm.platforms import current_platform
from .reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch
# Minimum amd-quark version for MXFP4/OCP_MX tests (single source of truth).
QUARK_MXFP4_MIN_VERSION = "0.8.99"
QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
importlib.metadata.version("amd-quark")
) >= version.parse("0.8.99")
) >= version.parse(QUARK_MXFP4_MIN_VERSION)
if QUARK_MXFP4_AVAILABLE:
from quark.torch.export.nn.modules.realquantizer import StaticScaledRealQuantizer
@@ -200,7 +203,10 @@ WIKITEXT_ACCURACY_CONFIGS = [
]
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.skipif(
not QUARK_MXFP4_AVAILABLE,
reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available",
)
@pytest.mark.parametrize("config", WIKITEXT_ACCURACY_CONFIGS)
@pytest.mark.parametrize("tp_size", [1, 2])
def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
@@ -231,7 +237,10 @@ def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
@pytest.mark.parametrize("config", GSM8K_ACCURACY_CONFIGS)
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.skipif(
not QUARK_MXFP4_AVAILABLE,
reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available",
)
@pytest.mark.skipif(
not HF_HUB_AMD_ORG_ACCESS,
reason="Read access to huggingface.co/amd is required for this test.",
@@ -261,7 +270,10 @@ def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig):
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.skipif(
not QUARK_MXFP4_AVAILABLE,
reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available",
)
@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]])
def test_mxfp4_fused_qdq_match_quark(float_dtype: torch.dtype, scalings: list[int]):
@@ -289,7 +301,10 @@ def test_mxfp4_fused_qdq_match_quark(float_dtype: torch.dtype, scalings: list[in
)
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.skipif(
not QUARK_MXFP4_AVAILABLE,
reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available",
)
@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]])
def test_mxfp4_dequant_kernel_match_quark(