[ROCm] add amd-quark package in requirements for rocm to use quantized models (#35658)
Signed-off-by: Hongxia Yang <hongxiay.yang@amd.com> Co-authored-by: Hongxia Yang <hongxiay.yang@amd.com>
This commit is contained in:
@@ -19,4 +19,7 @@ setuptools>=77.0.3,<80.0.0
|
||||
setuptools-scm>=8
|
||||
runai-model-streamer[s3,gcs]==0.15.3
|
||||
conch-triton-kernels==1.2.1
|
||||
timm>=1.0.17
|
||||
timm>=1.0.17
|
||||
# amd-quark: required for Quark quantization on ROCm
|
||||
# To be consistent with test_quark.py
|
||||
amd-quark>=0.8.99
|
||||
@@ -26,9 +26,12 @@ from vllm.platforms import current_platform
|
||||
|
||||
from .reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch
|
||||
|
||||
# Minimum amd-quark version for MXFP4/OCP_MX tests (single source of truth).
|
||||
QUARK_MXFP4_MIN_VERSION = "0.8.99"
|
||||
|
||||
QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
|
||||
importlib.metadata.version("amd-quark")
|
||||
) >= version.parse("0.8.99")
|
||||
) >= version.parse(QUARK_MXFP4_MIN_VERSION)
|
||||
|
||||
if QUARK_MXFP4_AVAILABLE:
|
||||
from quark.torch.export.nn.modules.realquantizer import StaticScaledRealQuantizer
|
||||
@@ -200,7 +203,10 @@ WIKITEXT_ACCURACY_CONFIGS = [
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
|
||||
@pytest.mark.skipif(
|
||||
not QUARK_MXFP4_AVAILABLE,
|
||||
reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available",
|
||||
)
|
||||
@pytest.mark.parametrize("config", WIKITEXT_ACCURACY_CONFIGS)
|
||||
@pytest.mark.parametrize("tp_size", [1, 2])
|
||||
def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
|
||||
@@ -231,7 +237,10 @@ def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("config", GSM8K_ACCURACY_CONFIGS)
|
||||
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
|
||||
@pytest.mark.skipif(
|
||||
not QUARK_MXFP4_AVAILABLE,
|
||||
reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not HF_HUB_AMD_ORG_ACCESS,
|
||||
reason="Read access to huggingface.co/amd is required for this test.",
|
||||
@@ -261,7 +270,10 @@ def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig):
|
||||
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
|
||||
@pytest.mark.skipif(
|
||||
not QUARK_MXFP4_AVAILABLE,
|
||||
reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available",
|
||||
)
|
||||
@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]])
|
||||
def test_mxfp4_fused_qdq_match_quark(float_dtype: torch.dtype, scalings: list[int]):
|
||||
@@ -289,7 +301,10 @@ def test_mxfp4_fused_qdq_match_quark(float_dtype: torch.dtype, scalings: list[in
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
|
||||
@pytest.mark.skipif(
|
||||
not QUARK_MXFP4_AVAILABLE,
|
||||
reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available",
|
||||
)
|
||||
@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]])
|
||||
def test_mxfp4_dequant_kernel_match_quark(
|
||||
|
||||
Reference in New Issue
Block a user