[Feature][OCP MX] Support mxfp6 and mixed mxfp6-mxfp4 (#21166)

This commit is contained in:
fxmarty-amd
2025-10-07 15:35:26 +02:00
committed by GitHub
parent 08d26a1b7e
commit 41f1cf38f2
18 changed files with 656 additions and 180 deletions

View File

@@ -11,6 +11,7 @@ import importlib.metadata
import os
from dataclasses import dataclass
from importlib.util import find_spec
from typing import Optional
import huggingface_hub
import lm_eval
@@ -148,39 +149,93 @@ def test_quark_fp8_parity(vllm_runner):
@dataclass
class ModelCase:
    """A model to exercise, paired with the tensor-parallel size to run it at."""

    # Hugging Face Hub identifier of the checkpoint under test.
    model_id: str
    # Tensor-parallel degree (number of GPUs the engine shards across).
    tp: int
@dataclass
class AccuracyTestConfig:
    """Expected-accuracy spec for one quantized model in an lm-eval run."""

    # Hugging Face Hub identifier passed to lm-eval as `pretrained`.
    model_name: str
    # NOTE(review): "excepted" looks like a typo for "expected"; kept as-is
    # because the test functions read the attribute under this name.
    excepted_value: float

    def get_model_args(
        self,
        tp_size: int,
        model_max_len: Optional[int] = None,
        kwargs: Optional[dict] = None,
    ) -> dict:
        """Build the `model_args` dict handed to lm-eval's vLLM backend.

        Args:
            tp_size: tensor-parallel size to launch the model with.
            model_max_len: if given, forwarded as ``max_model_len``.
            kwargs: extra engine arguments merged into the dict; keys given
                here override the defaults built below.

        Returns:
            A dict suitable for ``lm_eval.simple_evaluate(model_args=...)``.
        """
        # Default to a fresh dict per call — a mutable default argument
        # would be shared across invocations.
        if kwargs is None:
            kwargs = {}
        model_args = {
            "pretrained": self.model_name,
            "dtype": "auto",
            "add_bos_token": True,
            "tensor_parallel_size": tp_size,
            "gpu_memory_utilization": 0.7,
            **kwargs,
        }
        # Only constrain the context length when the caller asked for it.
        if model_max_len is not None:
            model_args["max_model_len"] = model_max_len
        return model_args
# GSM8K accuracy targets (exact-match scores; higher is better).
GSM8K_ACCURACY_CONFIGS = [
    # Private model.
    AccuracyTestConfig(
        model_name="amd/DeepSeek-R1-WMXFP4-AMXFP4-Scale-UINT8-MoE-Quant",
        excepted_value=0.96,
    ),
]

# Wikitext targets — presumably word perplexities (lower is better),
# checked within a tolerance by the test below; verify against the
# consuming test's assertion.
WIKITEXT_ACCURACY_CONFIGS = [
    AccuracyTestConfig(
        model_name="fxmarty/qwen1.5_moe_a2.7b_chat_w_fp4_a_fp6_e2m3",
        excepted_value=11.3,
    ),
    AccuracyTestConfig(
        model_name="fxmarty/qwen1.5_moe_a2.7b_chat_w_fp6_e3m2_a_fp6_e3m2",
        excepted_value=10.6,
    ),
    AccuracyTestConfig(
        model_name="fxmarty/qwen_1.5-moe-a2.7b-mxfp4", excepted_value=12.4
    ),
]
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.parametrize("config", WIKITEXT_ACCURACY_CONFIGS)
@pytest.mark.parametrize("tp_size", [1, 2])
def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
    """Check wikitext word-perplexity of OCP MX quantized models.

    Runs lm-eval's wikitext task through the vLLM backend and asserts the
    measured ``word_perplexity`` lies within ``rtol`` of the configured
    expected value.
    """
    if torch.cuda.device_count() < tp_size:
        pytest.skip(
            f"This test requires >={tp_size} gpus, got only {torch.cuda.device_count()}"
        )

    task = "wikitext"
    rtol = 0.1

    # Smaller cuda_graph_sizes to speed up the test.
    results = lm_eval.simple_evaluate(
        model="vllm",
        model_args=config.get_model_args(
            tp_size=tp_size, kwargs={"cuda_graph_sizes": [16]}
        ),
        tasks=task,
        batch_size=64,
    )

    EXPECTED_VALUE = config.excepted_value
    measured_value = results["results"][task]["word_perplexity,none"]
    assert (
        measured_value < EXPECTED_VALUE + rtol
        and measured_value > EXPECTED_VALUE - rtol
    ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
@pytest.mark.parametrize("config", GSM8K_ACCURACY_CONFIGS)
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.skipif(
not HF_HUB_AMD_ORG_ACCESS,
reason="Read access to huggingface.co/amd is required for this test.",
)
def test_mxfp4_gsm8k_correctness(config: GSM8KAccuracyTestConfig):
def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig):
if torch.cuda.device_count() < 8:
pytest.skip(
f"This test requires >=8 gpus, got only {torch.cuda.device_count()}"
@@ -193,7 +248,7 @@ def test_mxfp4_gsm8k_correctness(config: GSM8KAccuracyTestConfig):
results = lm_eval.simple_evaluate(
model="vllm",
model_args=config.get_model_args(),
model_args=config.get_model_args(tp_size=8, model_max_len=38768),
tasks=task,
batch_size=64,
num_fewshot=8,