[Quant][Feature] Support online MXFP8 quantization for MoE and dense models (#35448)
Signed-off-by: EdalatiAli <aliedalati@cohere.com>
This commit is contained in:
104
tests/models/quantization/test_mxfp8.py
Normal file
104
tests/models/quantization/test_mxfp8.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""E2E tests for online MXFP8 quantization.
|
||||
|
||||
Loads a BF16 model with ``--quantization mxfp8`` (online quantization) and
|
||||
compares log-probabilities against the same model served in BF16 without
|
||||
quantization. This exercises the full pipeline: config parsing,
|
||||
``Mxfp8OnlineLinearMethod``, ``Mxfp8OnlineMoEMethod``, weight loading,
|
||||
online quantization / shuffling, and inference through ``apply_monolithic``.
|
||||
|
||||
Layer skipping (``modules_to_not_convert``) is configured in the model's
|
||||
``config.json`` under ``quantization_config`` and is not tested here.
|
||||
|
||||
``example_prompts`` is a pytest fixture (from conftest.py) that loads 8
|
||||
diverse prompts from ``tests/prompts/example.txt``.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
|
||||
from ..utils import check_logprobs_close
|
||||
|
||||
# A small MoE model that fits on a single GPU and has both linear + MoE layers.
|
||||
MOE_MODEL = "Qwen/Qwen3-30B-A3B"
|
||||
# A small dense model (no MoE) to validate the linear-only path.
|
||||
DENSE_MODEL = "Qwen/Qwen3-0.6B"
|
||||
|
||||
MAX_MODEL_LEN = 1024
|
||||
MAX_TOKENS = 4
|
||||
NUM_LOG_PROBS = 8
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("mxfp8"),
|
||||
reason="mxfp8 is not supported on this GPU type (requires sm_100+).",
|
||||
)
|
||||
@pytest.mark.quant_model
|
||||
@pytest.mark.parametrize("model", [DENSE_MODEL, MOE_MODEL], ids=["dense", "moe"])
|
||||
def test_mxfp8_logprobs(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Compare BF16 baseline logprobs against online MXFP8-quantized model.
|
||||
|
||||
Runs the same model twice -- once in BF16 (baseline) and once with
|
||||
online MXFP8 quantization -- then checks that the top log-probabilities
|
||||
are close. Only 4 tokens are generated to keep the test fast while
|
||||
still catching numerical divergence.
|
||||
"""
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("TOKENIZERS_PARALLELISM", "true")
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
enforce_eager=True,
|
||||
) as vllm_model:
|
||||
baseline_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, MAX_TOKENS, NUM_LOG_PROBS
|
||||
)
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
enforce_eager=True,
|
||||
quantization="mxfp8",
|
||||
) as vllm_model:
|
||||
test_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, MAX_TOKENS, NUM_LOG_PROBS
|
||||
)
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=baseline_outputs,
|
||||
outputs_1_lst=test_outputs,
|
||||
name_0="bf16",
|
||||
name_1="mxfp8",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("mxfp8"),
|
||||
reason="mxfp8 is not supported on this GPU type (requires sm_100+).",
|
||||
)
|
||||
@pytest.mark.quant_model
|
||||
@pytest.mark.parametrize("model", [DENSE_MODEL, MOE_MODEL], ids=["dense", "moe"])
|
||||
def test_mxfp8_generation(vllm_runner, model: str) -> None:
|
||||
"""Smoke test: verify online MXFP8 model generates coherent text."""
|
||||
prompt = "1 2 3 4 5"
|
||||
with vllm_runner(
|
||||
model,
|
||||
enforce_eager=True,
|
||||
quantization="mxfp8",
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
) as vllm_model:
|
||||
output = vllm_model.generate_greedy([prompt], max_tokens=5)
|
||||
|
||||
generated = output[0][1]
|
||||
assert len(generated) > len(prompt), (
|
||||
f"MXFP8 model produced no new tokens. Output: {generated!r}"
|
||||
)
|
||||
Reference in New Issue
Block a user