[Feature][OCP MX] Support mxfp6 and mixed mxfp6-mxfp4 (#21166)
This commit is contained in:
@@ -11,6 +11,7 @@ import importlib.metadata
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from importlib.util import find_spec
|
||||
from typing import Optional
|
||||
|
||||
import huggingface_hub
|
||||
import lm_eval
|
||||
@@ -148,39 +149,93 @@ def test_quark_fp8_parity(vllm_runner):
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelCase:
|
||||
model_id: str
|
||||
tp: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class GSM8KAccuracyTestConfig:
|
||||
class AccuracyTestConfig:
|
||||
model_name: str
|
||||
excepted_value: float
|
||||
|
||||
def get_model_args(self) -> str:
|
||||
return (
|
||||
f"pretrained={self.model_name},"
|
||||
"dtype=auto,add_bos_token=True,tensor_parallel_size=8,gpu_memory_utilization=0.7,max_model_len=38768"
|
||||
)
|
||||
def get_model_args(
|
||||
self,
|
||||
tp_size: int,
|
||||
model_max_len: Optional[int] = None,
|
||||
kwargs: Optional[dict] = None,
|
||||
) -> dict:
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
model_args = {
|
||||
"pretrained": self.model_name,
|
||||
"dtype": "auto",
|
||||
"add_bos_token": True,
|
||||
"tensor_parallel_size": tp_size,
|
||||
"gpu_memory_utilization": 0.7,
|
||||
**kwargs,
|
||||
}
|
||||
if model_max_len is not None:
|
||||
model_args["max_model_len"] = model_max_len
|
||||
|
||||
return model_args
|
||||
|
||||
|
||||
ACCURACY_CONFIGS = [
|
||||
GSM8K_ACCURACY_CONFIGS = [
|
||||
# Private model.
|
||||
GSM8KAccuracyTestConfig(
|
||||
AccuracyTestConfig(
|
||||
model_name="amd/DeepSeek-R1-WMXFP4-AMXFP4-Scale-UINT8-MoE-Quant",
|
||||
excepted_value=0.96,
|
||||
),
|
||||
]
|
||||
|
||||
WIKITEXT_ACCURACY_CONFIGS = [
|
||||
AccuracyTestConfig(
|
||||
model_name="fxmarty/qwen1.5_moe_a2.7b_chat_w_fp4_a_fp6_e2m3",
|
||||
excepted_value=11.3,
|
||||
),
|
||||
AccuracyTestConfig(
|
||||
model_name="fxmarty/qwen1.5_moe_a2.7b_chat_w_fp6_e3m2_a_fp6_e3m2",
|
||||
excepted_value=10.6,
|
||||
),
|
||||
AccuracyTestConfig(
|
||||
model_name="fxmarty/qwen_1.5-moe-a2.7b-mxfp4", excepted_value=12.4
|
||||
),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("config", ACCURACY_CONFIGS)
|
||||
|
||||
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
|
||||
@pytest.mark.parametrize("config", WIKITEXT_ACCURACY_CONFIGS)
|
||||
@pytest.mark.parametrize("tp_size", [1, 2])
|
||||
def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
|
||||
if torch.cuda.device_count() < tp_size:
|
||||
pytest.skip(
|
||||
f"This test requires >={tp_size} gpus, got only {torch.cuda.device_count()}"
|
||||
)
|
||||
|
||||
task = "wikitext"
|
||||
rtol = 0.1
|
||||
|
||||
# Smaller cuda_graph_sizes to speed up the test.
|
||||
results = lm_eval.simple_evaluate(
|
||||
model="vllm",
|
||||
model_args=config.get_model_args(
|
||||
tp_size=tp_size, kwargs={"cuda_graph_sizes": [16]}
|
||||
),
|
||||
tasks=task,
|
||||
batch_size=64,
|
||||
)
|
||||
|
||||
EXPECTED_VALUE = config.excepted_value
|
||||
measured_value = results["results"][task]["word_perplexity,none"]
|
||||
assert (
|
||||
measured_value < EXPECTED_VALUE + rtol
|
||||
and measured_value > EXPECTED_VALUE - rtol
|
||||
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("config", GSM8K_ACCURACY_CONFIGS)
|
||||
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
|
||||
@pytest.mark.skipif(
|
||||
not HF_HUB_AMD_ORG_ACCESS,
|
||||
reason="Read access to huggingface.co/amd is required for this test.",
|
||||
)
|
||||
def test_mxfp4_gsm8k_correctness(config: GSM8KAccuracyTestConfig):
|
||||
def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig):
|
||||
if torch.cuda.device_count() < 8:
|
||||
pytest.skip(
|
||||
f"This test requires >=8 gpus, got only {torch.cuda.device_count()}"
|
||||
@@ -193,7 +248,7 @@ def test_mxfp4_gsm8k_correctness(config: GSM8KAccuracyTestConfig):
|
||||
|
||||
results = lm_eval.simple_evaluate(
|
||||
model="vllm",
|
||||
model_args=config.get_model_args(),
|
||||
model_args=config.get_model_args(tp_size=8, model_max_len=38768),
|
||||
tasks=task,
|
||||
batch_size=64,
|
||||
num_fewshot=8,
|
||||
|
||||
Reference in New Issue
Block a user