[Feature][Quantization] MXFP4 support for MOE models (#17888)

Signed-off-by: Felix Marty <felmarty@amd.com>
Signed-off-by: Bowen Bao <bowenbao@amd.com>
Signed-off-by: Felix Marty <Felix.Marty@amd.com>
Co-authored-by: Bowen Bao <bowenbao@amd.com>
This commit is contained in:
fxmarty-amd
2025-07-09 22:19:02 +02:00
committed by GitHub
parent bf03ff3575
commit 332d4cb17b
15 changed files with 873 additions and 104 deletions

View File

@@ -3,15 +3,44 @@
"""Test model set-up and weight loading for quark-quantized models.
Run `pytest tests/quantization/test_quark.py`.
See also `tests/kernels/moe/test_mxfp4_moe.py`.
"""
import importlib
import importlib.metadata
import os
from dataclasses import dataclass
import huggingface_hub
import lm_eval
import pytest
import torch
from packaging import version
from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501
QuarkLinearMethod, QuarkW8A8Fp8, QuarkW8A8Int8)
from vllm.platforms import current_platform
from .reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch
QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
"quark") is not None and version.parse(
importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')
if QUARK_MXFP4_AVAILABLE:
from quark.torch.export.nn.modules.realquantizer import (
StaticScaledRealQuantizer)
from quark.torch.kernel import mx as mx_kernel
from quark.torch.quantization.config.config import FP4PerGroupSpec
try:
huggingface_hub.list_repo_refs(
"amd/Llama-3.3-70B-Instruct-WMXFP4-AMXFP4-KVFP8-Scale-UINT8-SQ")
HF_HUB_AMD_ORG_ACCESS = True
except huggingface_hub.errors.RepositoryNotFoundError:
HF_HUB_AMD_ORG_ACCESS = False
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
@@ -90,3 +119,145 @@ def test_quark_fp8_parity(vllm_runner):
for key in fp8_state_dict:
assert torch.equal(fp8_state_dict[key], quark_state_dict[key])
@dataclass
class ModelCase:
model_id: str
tp: int
@dataclass
class GSM8KAccuracyTestConfig:
model_name: str
excepted_value: float
def get_model_args(self) -> str:
return (
f"pretrained={self.model_name},"
"dtype=auto,add_bos_token=True,tensor_parallel_size=8,gpu_memory_utilization=0.7,max_model_len=38768"
)
ACCURACY_CONFIGS = [
# Private model.
GSM8KAccuracyTestConfig(
model_name="amd/DeepSeek-R1-WMXFP4-AMXFP4-Scale-UINT8-MoE-Quant",
excepted_value=0.96),
]
@pytest.mark.parametrize("config", ACCURACY_CONFIGS)
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE,
reason="amd-quark>=0.9 is not available")
@pytest.mark.skipif(
not HF_HUB_AMD_ORG_ACCESS,
reason="Read access to huggingface.co/amd is required for this test.")
def test_mxfp4_gsm8k_correctness(config: GSM8KAccuracyTestConfig):
if torch.cuda.device_count() < 8:
pytest.skip(
f"This test requires >=8 gpus, got only {torch.cuda.device_count()}"
)
task = "gsm8k"
rtol = 0.03
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
results = lm_eval.simple_evaluate(
model="vllm",
model_args=config.get_model_args(),
tasks=task,
batch_size=64,
num_fewshot=8,
)
EXPECTED_VALUE = config.excepted_value
measured_value = results["results"][task]["exact_match,strict-match"]
assert (measured_value - rtol < EXPECTED_VALUE
and measured_value + rtol > EXPECTED_VALUE
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
del os.environ["VLLM_USE_TRITON_FLASH_ATTN"]
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE,
reason="amd-quark>=0.9 is not available")
@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("scalings",
[[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]])
def test_mxfp4_fused_qdq_match_quark(float_dtype: torch.dtype,
scalings: list[int]):
torch.manual_seed(0)
hidden_size = 64 * 32
inp = (torch.rand(1, hidden_size, dtype=float_dtype, device="cuda") -
0.5) * 2
for i in range(hidden_size // 32):
inp[:, i * 32:(i + 1) *
32] = inp[:, i * 32:(i + 1) * 32] * scalings[i % len(scalings)]
inp_kernel = inp.clone()
inp_kernel_clone = inp_kernel.clone()
res_hip = mx_kernel.qdq_mxfp4_hip(inp_kernel_clone, "even")
res_torch = qdq_mxfp4_torch(inp_kernel, "even")
for i in range(hidden_size // 32):
assert torch.all(torch.isfinite(res_hip[:, i * 32:(i + 1) * 32]))
assert torch.all(torch.isfinite(res_torch[:, i * 32:(i + 1) * 32]))
torch.testing.assert_close(res_hip[:, i * 32:(i + 1) * 32],
res_torch[:, i * 32:(i + 1) * 32])
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE,
reason="amd-quark>=0.9 is not available")
@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("scalings",
[[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]])
def test_mxfp4_dequant_kernel_match_quark(float_dtype: torch.dtype,
scalings: list[int]):
qspec = FP4PerGroupSpec(
ch_axis=-1,
group_size=32,
scale_format="e8m0",
scale_calculation_mode="even",
is_dynamic=False,
).to_quantization_spec()
weight_quantizer = StaticScaledRealQuantizer(
qspec=qspec,
quantizer=None,
reorder=False,
real_quantized=True,
float_dtype=float_dtype,
device="cuda",
)
observer = qspec.observer_cls(qspec, device="cuda")
hidden_size = 512
shape = (11008, hidden_size)
w = (torch.rand(shape, device="cuda", dtype=float_dtype) - 0.5) * 2
# Make it so that different groups have different scales.
for i in range(hidden_size // 32):
w[:, i * 32:(i + 1) *
32] = w[:, i * 32:(i + 1) * 32] * scalings[i % len(scalings)]
observer(w)
scale, _ = observer._calculate_qparams()
weight_quantizer.scale = scale
w_mxfp4 = weight_quantizer.to_real_quantize_params(w).to("cuda")
weight_quantizer.maybe_convert_and_transpose_scale()
scale = weight_quantizer.scale
out_hip = mx_kernel.dq_mxfp4_hip(w_mxfp4, scale, float_dtype)
out_torch = dq_mxfp4_torch(w_mxfp4, scale, float_dtype)
assert torch.equal(out_hip, out_torch)