[Quant][Feature] Support online MXFP8 quantization for MoE and dense models (#35448)

Signed-off-by: EdalatiAli <aliedalati@cohere.com>
This commit is contained in:
EdalatiAli
2026-03-16 18:07:39 -04:00
committed by GitHub
parent fd4d96302a
commit e5b807607c
10 changed files with 747 additions and 56 deletions

View File

@@ -0,0 +1,104 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""E2E tests for online MXFP8 quantization.
Loads a BF16 model with ``--quantization mxfp8`` (online quantization) and
compares log-probabilities against the same model served in BF16 without
quantization. This exercises the full pipeline: config parsing,
``Mxfp8OnlineLinearMethod``, ``Mxfp8OnlineMoEMethod``, weight loading,
online quantization / shuffling, and inference through ``apply_monolithic``.
Layer skipping (``modules_to_not_convert``) is configured in the model's
``config.json`` under ``quantization_config`` and is not tested here.
``example_prompts`` is a pytest fixture (from conftest.py) that loads 8
diverse prompts from ``tests/prompts/example.txt``.
"""
import pytest
from tests.quantization.utils import is_quant_method_supported
from ..utils import check_logprobs_close
# A small MoE model that fits on a single GPU and has both linear + MoE layers.
MOE_MODEL = "Qwen/Qwen3-30B-A3B"
# A small dense model (no MoE) to validate the linear-only path.
DENSE_MODEL = "Qwen/Qwen3-0.6B"
MAX_MODEL_LEN = 1024
MAX_TOKENS = 4
NUM_LOG_PROBS = 8
@pytest.mark.skipif(
not is_quant_method_supported("mxfp8"),
reason="mxfp8 is not supported on this GPU type (requires sm_100+).",
)
@pytest.mark.quant_model
@pytest.mark.parametrize("model", [DENSE_MODEL, MOE_MODEL], ids=["dense", "moe"])
def test_mxfp8_logprobs(
vllm_runner,
example_prompts,
model: str,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Compare BF16 baseline logprobs against online MXFP8-quantized model.
Runs the same model twice -- once in BF16 (baseline) and once with
online MXFP8 quantization -- then checks that the top log-probabilities
are close. Only 4 tokens are generated to keep the test fast while
still catching numerical divergence.
"""
with monkeypatch.context() as m:
m.setenv("TOKENIZERS_PARALLELISM", "true")
with vllm_runner(
model,
max_model_len=MAX_MODEL_LEN,
enforce_eager=True,
) as vllm_model:
baseline_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, MAX_TOKENS, NUM_LOG_PROBS
)
with vllm_runner(
model,
max_model_len=MAX_MODEL_LEN,
enforce_eager=True,
quantization="mxfp8",
) as vllm_model:
test_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, MAX_TOKENS, NUM_LOG_PROBS
)
check_logprobs_close(
outputs_0_lst=baseline_outputs,
outputs_1_lst=test_outputs,
name_0="bf16",
name_1="mxfp8",
)
@pytest.mark.skipif(
not is_quant_method_supported("mxfp8"),
reason="mxfp8 is not supported on this GPU type (requires sm_100+).",
)
@pytest.mark.quant_model
@pytest.mark.parametrize("model", [DENSE_MODEL, MOE_MODEL], ids=["dense", "moe"])
def test_mxfp8_generation(vllm_runner, model: str) -> None:
"""Smoke test: verify online MXFP8 model generates coherent text."""
prompt = "1 2 3 4 5"
with vllm_runner(
model,
enforce_eager=True,
quantization="mxfp8",
max_model_len=MAX_MODEL_LEN,
) as vllm_model:
output = vllm_model.generate_greedy([prompt], max_tokens=5)
generated = output[0][1]
assert len(generated) > len(prompt), (
f"MXFP8 model produced no new tokens. Output: {generated!r}"
)

View File

@@ -23,6 +23,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8Dynamic128Sym,
kFp8Static128BlockSym,
kFp8StaticTensorSym,
kMxfp8Dynamic,
kMxfp8Static,
)
from vllm.platforms import current_platform
@@ -67,11 +69,54 @@ class TrtLlmFp8ExpertsBase:
"""Does not support non-gated MoE (i.e. Nanotron-3-Nano)."""
return True
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""Supports Fp8 per-tensor, Fp8 block, and MXFP8."""
SUPPORTED_W_A = [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
(kMxfp8Static, kMxfp8Dynamic),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
"""Supports only SiLU and RELU^2 non-gated activation."""
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
@staticmethod
def _supports_routing_method(
routing_method: RoutingMethodType,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""Monolithic kernels need to express router support."""
# NOTE(dbari): TopK routing could also be enabled, but need to validate models
# NOTE(dbari): Default is not implemented and should not be enabled until it is
if (weight_key, activation_key) in [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kMxfp8Static, kMxfp8Dynamic),
]:
# NOTE(rob): potentially allow others here. This is a conservative list.
return routing_method in [
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
]
elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
# NOTE(dbari): as above, potentially allow others here.
return routing_method in [
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Llama4,
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
]
else:
raise ValueError("Unsupported quantization scheme.")
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
"""Monolithic kernel so only use with naive DP/EP and TP."""
@@ -113,9 +158,10 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""Supports Fp8 block."""
"""Supports Fp8 block and MXFP8."""
SUPPORTED_W_A = [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kMxfp8Static, kMxfp8Dynamic),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@@ -159,6 +205,7 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
apply_router_weight_on_input: bool,
):
import flashinfer
from flashinfer.fused_moe import Fp8QuantizationType
# Pack topk_ids and topk_weights into single tensor
# Format: (expert_id << 16) | (weight_bf16.view(int16))
@@ -175,6 +222,16 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
assert a1q_scale is not None
is_mxfp8 = self.quant_config.block_shape == [1, 32]
if is_mxfp8:
fp8_quant_type = Fp8QuantizationType.MxFp8
use_shuffled_weight = True
hidden_states_scale = a1q_scale
else:
fp8_quant_type = Fp8QuantizationType.DeepSeekFp8
use_shuffled_weight = False
hidden_states_scale = a1q_scale.t().contiguous()
# `trtllm_fp8_block_scale_routed_moe` has a bug and does not write to the
# output tensor in-place so we need to manually copy the result to the
# output tensor
@@ -183,7 +240,7 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
topk_ids=packed_topk_ids,
routing_bias=None,
hidden_states=hidden_states,
hidden_states_scale=a1q_scale.t().contiguous(), # type: ignore[union-attr]
hidden_states_scale=hidden_states_scale,
gemm1_weights=w1,
gemm1_weights_scale=self.quant_config.w1_scale,
gemm2_weights=w2,
@@ -197,8 +254,9 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
local_num_experts=self.local_num_experts,
routed_scaling_factor=None,
routing_method_type=1,
use_shuffled_weight=False,
use_shuffled_weight=use_shuffled_weight,
weight_layout=0,
fp8_quantization_type=fp8_quant_type,
# output=output,
)
output.copy_(result)
@@ -240,10 +298,11 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""Supports Fp8 per-tensor and Fp8 block."""
"""Supports Fp8 per-tensor, Fp8 block, and MXFP8."""
SUPPORTED_W_A = [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
(kMxfp8Static, kMxfp8Dynamic),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@@ -256,7 +315,10 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
"""Monolithic kernels need to express router support."""
# NOTE(dbari): TopK routing could also be enabled, but need to validate models
# NOTE(dbari): Default is not implemented and should not be enabled until it is
if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
if (weight_key, activation_key) in [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kMxfp8Static, kMxfp8Dynamic),
]:
# NOTE(rob): potentially allow others here. This is a conservative list.
return routing_method in [
RoutingMethodType.DeepSeekV3,
@@ -274,7 +336,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
else:
raise ValueError("Unsupported quantization scheme.")
def _apply_per_block(
def _apply_block_scale(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
@@ -291,32 +353,38 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
routed_scaling_factor: float | None = None,
topk_group: int | None = None,
) -> torch.Tensor:
# Delay import for non-CUDA.
import flashinfer
from flashinfer.fused_moe import Fp8QuantizationType
assert not apply_router_weight_on_input
assert activation == MoEActivation.SILU
assert self.topk <= global_num_experts
assert self.topk <= 10
assert global_num_experts % 4 == 0
assert self.quant_config.block_shape in [[128, 128], [1, 32]]
# Kernel expects #experts <= #threads 512
assert global_num_experts <= 512
# TODO: fuse into the quant kernel.
assert a1q_scale is not None
if self.routing_method_type == RoutingMethodType.DeepSeekV3:
router_logits = router_logits.to(torch.float32)
assert self.topk <= global_num_experts
assert self.topk <= 10
assert global_num_experts % 4 == 0
assert self.quant_config.block_shape == [128, 128]
# Routing kernel expects #experts <= #threads 512
assert global_num_experts <= 512
# Kernel requires transposed hidden state scales
# TODO: fuse into the quant kernel.
assert a1q_scale is not None
a1q_scale_t = a1q_scale.t().contiguous()
is_mxfp8 = self.quant_config.block_shape == [1, 32]
if is_mxfp8:
fp8_quant_type = Fp8QuantizationType.MxFp8
use_shuffled_weight = True
hidden_states_scale = a1q_scale
else:
fp8_quant_type = Fp8QuantizationType.DeepSeekFp8
use_shuffled_weight = False
hidden_states_scale = a1q_scale.t().contiguous()
return flashinfer.fused_moe.trtllm_fp8_block_scale_moe(
routing_logits=router_logits,
routing_bias=e_score_correction_bias,
hidden_states=hidden_states,
hidden_states_scale=a1q_scale_t,
hidden_states_scale=hidden_states_scale,
gemm1_weights=w1,
gemm1_weights_scale=self.quant_config.w1_scale,
gemm2_weights=w2,
@@ -330,7 +398,8 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
local_num_experts=self.local_num_experts,
routed_scaling_factor=routed_scaling_factor,
routing_method_type=self.routing_method_type,
use_shuffled_weight=False,
use_shuffled_weight=use_shuffled_weight,
fp8_quantization_type=fp8_quant_type,
)
def _apply_per_tensor(
@@ -409,7 +478,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
topk_group: int | None = None,
) -> torch.Tensor:
if self.quant_config.block_shape is not None:
return self._apply_per_block(
return self._apply_block_scale(
hidden_states,
w1,
w2,
@@ -441,6 +510,6 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
)
else:
raise NotImplementedError(
"Only per-block and per-tensor quantization are supported in "
f"{self.__class__.__name__}."
"Only per-block, per-tensor, and MXFP8 quantization are "
f"supported in {self.__class__.__name__}."
)

View File

@@ -444,7 +444,7 @@ def convert_to_fp8_moe_kernel_format(
Fp8MoeBackend.FLASHINFER_CUTLASS,
Fp8MoeBackend.FLASHINFER_TRTLLM,
]:
w13, w2, w13_scale = prepare_fp8_moe_layer_for_fi(
w13, w2, w13_scale, w2_scale = prepare_fp8_moe_layer_for_fi(
layer=layer,
w13=w13,
w2=w2,
@@ -512,6 +512,21 @@ def make_fp8_moe_quant_config(
g1_alphas=(w1_scale * a1_scale).squeeze(),
g2_alphas=(w2_scale * a2_scale).squeeze(),
)
# MXFP8 uses "mxfp8" quant_dtype so the prepare step dispatches to
# _mxfp8_e4m3_quantize rather than standard FP8 block quantization.
# Non-swizzled layout is required since the TRTLLM kernel expects
# scales in (num_tokens, hidden_dim // 32) format.
if block_shape == [1, 32]:
return FusedMoEQuantConfig.make(
"mxfp8",
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
is_nvfp4_scale_swizzled=False,
)
# All other backends use normal config.
return fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,

View File

@@ -1,44 +1,87 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
backend_to_kernel_cls,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kMxfp8Dynamic,
kMxfp8Static,
)
logger = init_logger(__name__)
_SUPPORTED_BACKENDS: frozenset[Fp8MoeBackend] = frozenset(
{
Fp8MoeBackend.FLASHINFER_TRTLLM,
}
)
class MxFp8MoeBackend(Enum):
FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM"
_BACKEND_NAME_MAP: dict[str, Fp8MoeBackend] = {
"flashinfer_trtllm": Fp8MoeBackend.FLASHINFER_TRTLLM,
}
def _select_kernel_cls(
backend: Fp8MoeBackend,
config: FusedMoEConfig,
) -> type[mk.FusedMoEExperts]:
"""Select the first supported expert class for the MXFP8 config."""
activation_format = (
mk.FusedMoEActivationFormat.BatchedExperts
if config.moe_parallel_config.use_batched_activation_format
else mk.FusedMoEActivationFormat.Standard
)
last_reason: str | None = None
for cls in backend_to_kernel_cls(backend):
supported, reason = cls.is_supported_config(
cls,
config,
kMxfp8Static,
kMxfp8Dynamic,
activation_format,
)
if supported:
return cls
last_reason = reason
raise ValueError(
f"No supported MXFP8 expert class for {backend.value}: {last_reason}"
)
def select_mxfp8_moe_backend(
config: FusedMoEConfig,
) -> MxFp8MoeBackend:
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts]]:
"""Select the MXFP8 MoE backend and the best expert class.
Returns:
A tuple of (fp8_backend, experts_cls).
"""
if config.is_lora_enabled:
raise NotImplementedError("LoRA is not supported for MXFP8 MoE.")
AVAILABLE_BACKENDS = [
MxFp8MoeBackend.FLASHINFER_TRTLLM,
]
runner_backend = config.moe_backend
if runner_backend != "auto":
mapping = {
"flashinfer_trtllm": MxFp8MoeBackend.FLASHINFER_TRTLLM,
}
if backend := mapping.get(runner_backend):
logger.info_once(
"Using '%s' MxFp8 MoE backend (user-requested).",
backend.value,
backend = _BACKEND_NAME_MAP.get(runner_backend)
if backend is None:
raise ValueError(
f"moe_backend='{runner_backend}' is not supported for "
f"MXFP8 MoE. Expected one of "
f"{list(_BACKEND_NAME_MAP.keys())}."
)
return backend
raise ValueError(
f"moe_backend='{runner_backend}' is not supported for MXFP8 MoE. "
f"Expected one of {list(mapping.keys())}."
logger.info_once(
"Using '%s' MxFp8 MoE backend (user-requested).",
backend.value,
)
return backend, _select_kernel_cls(backend, config)
# Auto-select: only one backend available for now.
backend = AVAILABLE_BACKENDS[0]
logger.info_once("Using '%s' MxFp8 MoE backend.", backend.value)
return backend
# Auto-select: pick the first supported backend.
for backend in _SUPPORTED_BACKENDS:
logger.info_once("Using '%s' MxFp8 MoE backend.", backend.value)
return backend, _select_kernel_cls(backend, config)
raise ValueError("No MXFP8 MoE backends available.")

View File

@@ -199,7 +199,7 @@ def _mxfp8_e4m3_quantize(
) -> tuple[torch.Tensor, torch.Tensor]:
assert A_scale is None
assert not per_act_token_quant
assert block_shape is None
assert block_shape is None or block_shape == [1, 32]
return mxfp8_e4m3_quantize(A, is_sf_swizzled_layout)

View File

@@ -31,6 +31,7 @@ QuantizationMethods = Literal[
"torchao",
"inc",
"mxfp4",
"mxfp8",
"petit_nvfp4",
"cpu_awq",
]
@@ -129,6 +130,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
)
from .moe_wna16 import MoeWNA16Config
from .mxfp4 import Mxfp4Config
from .mxfp8 import Mxfp8Config
from .petit import PetitNvFp4Config
from .ptpc_fp8 import PTPCFp8Config
from .torchao import TorchAOConfig
@@ -156,6 +158,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"auto-round": INCConfig,
"inc": INCConfig,
"mxfp4": Mxfp4Config,
"mxfp8": Mxfp8Config,
"petit_nvfp4": PetitNvFp4Config,
"cpu_awq": CPUAWQConfig,
}

View File

@@ -25,13 +25,13 @@ from vllm.model_executor.layers.fused_moe.layer import (
FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
make_fp8_moe_quant_config,
select_fp8_moe_backend,
)
from vllm.model_executor.layers.fused_moe.oracle.mxfp8 import (
MxFp8MoeBackend,
select_mxfp8_moe_backend,
)
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
@@ -1712,8 +1712,7 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase):
self.quant_config = quant_config
assert self.quant_config.is_checkpoint_mxfp8_serialized
# Select MXFP8 MoE backend
self.mxfp8_backend = select_mxfp8_moe_backend(self.moe)
self.mxfp8_backend, _ = select_mxfp8_moe_backend(self.moe)
def create_weights(
self,
@@ -1943,7 +1942,7 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase):
@property
def is_monolithic(self) -> bool:
return self.mxfp8_backend == MxFp8MoeBackend.FLASHINFER_TRTLLM
return self.mxfp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
def apply_monolithic(
self,
@@ -1956,7 +1955,7 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase):
Fp8QuantizationType,
)
assert self.mxfp8_backend == MxFp8MoeBackend.FLASHINFER_TRTLLM
assert self.mxfp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
if layer.enable_eplb:
raise NotImplementedError(

View File

@@ -0,0 +1,354 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Online MXFP8 (microscaling FP8, block-32) quantization config and methods."""
from typing import Any
import torch
from torch.nn import Module
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.oracle.mxfp8 import (
select_mxfp8_moe_backend,
)
from vllm.model_executor.layers.linear import (
LinearBase,
UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.fp8 import (
Fp8Config,
Fp8KVCacheMethod,
Fp8OnlineLinearMethod,
Fp8OnlineMoEMethod,
_copy_missing_attrs,
)
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
MXFP8_BLOCK_SIZE,
Mxfp8LinearBackend,
Mxfp8LinearOp,
mxfp8_e4m3_quantize,
swizzle_mxfp8_scale,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped,
)
from vllm.model_executor.model_loader.weight_utils import (
initialize_single_dummy_weight,
)
from vllm.model_executor.parameter import ModelWeightParameter
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform
logger = init_logger(__name__)
class Mxfp8Config(Fp8Config):
"""Config class for online MXFP8 MoE quantization."""
def __init__(
self,
activation_scheme: str = "dynamic",
ignored_layers: list[str] | None = None,
) -> None:
if activation_scheme != "dynamic":
raise ValueError("mxfp8 only supports dynamic activation scheme.")
super().__init__(
is_checkpoint_fp8_serialized=False,
activation_scheme=activation_scheme,
ignored_layers=ignored_layers,
weight_block_size=None,
)
@classmethod
def get_name(cls) -> QuantizationMethods:
return "mxfp8"
@classmethod
def get_min_capability(cls) -> int:
return 100
@classmethod
def from_config(cls, config: dict[str, Any]) -> "Mxfp8Config":
activation_scheme = cls.get_from_keys_or(
config, ["activation_scheme"], "dynamic"
)
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
if not ignored_layers:
ignored_layers = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None
)
return cls(
activation_scheme=activation_scheme,
ignored_layers=ignored_layers,
)
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> "QuantizeMethodBase | None":
if isinstance(layer, LinearBase):
if is_layer_skipped(
prefix=prefix,
ignored_layers=self.ignored_layers,
fused_mapping=self.packed_modules_mapping,
skip_with_substr=True,
):
return UnquantizedLinearMethod()
return Mxfp8OnlineLinearMethod(self)
elif isinstance(layer, FusedMoE):
if is_layer_skipped(
prefix=prefix,
ignored_layers=self.ignored_layers,
fused_mapping=self.packed_modules_mapping,
skip_with_substr=True,
):
return UnquantizedFusedMoEMethod(layer.moe_config)
return Mxfp8OnlineMoEMethod(self, layer)
elif isinstance(layer, Attention):
return Fp8KVCacheMethod(self)
return None
class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
"""Online MXFP8 linear method.
Loads bf16/fp16 checkpoints and quantizes weights to MXFP8 (microscaling
FP8 with block-32 scales) during weight loading.
Args:
quant_config: The MXFP8 quantization config.
"""
uses_meta_device: bool = True
def __init__(self, quant_config: "Mxfp8Config"):
self.quant_config = quant_config
self.out_dtype = torch.get_default_dtype()
self.mxfp8_linear = Mxfp8LinearOp(self._select_backend())
logger.info_once(
"Using %s backend for MXFP8 GEMM", self.mxfp8_linear.backend.value
)
@staticmethod
def _select_backend() -> Mxfp8LinearBackend:
try:
from vllm.utils import flashinfer as fi
_ = fi.mm_mxfp8
return Mxfp8LinearBackend.FLASHINFER_CUTLASS
except Exception:
logger.warning(
"FlashInfer mm_mxfp8 not available, "
"falling back to MXFP8 emulation backend."
)
return Mxfp8LinearBackend.EMULATION
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
if input_size_per_partition % MXFP8_BLOCK_SIZE != 0:
raise ValueError(
f"MXFP8 requires input_size_per_partition "
f"({input_size_per_partition}) to be divisible by "
f"{MXFP8_BLOCK_SIZE}."
)
super().create_weights(
layer,
input_size_per_partition,
output_partition_sizes,
input_size,
output_size,
params_dtype,
**extra_weight_attrs,
)
def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return
if layer.weight.device == torch.device("meta"):
weight = ModelWeightParameter(
data=torch.empty_like(layer.weight, device=layer._load_device),
input_dim=1,
output_dim=0,
weight_loader=layer.weight.weight_loader,
)
_copy_missing_attrs(layer.weight, weight)
layer.register_parameter("weight", weight)
initialize_single_dummy_weight(layer.weight)
weight_fp8, weight_scale = mxfp8_e4m3_quantize(layer.weight.contiguous())
if self.mxfp8_linear.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS:
N, K = layer.weight.shape[0], layer.weight.shape[1]
weight_scale = swizzle_mxfp8_scale(weight_scale, N, K)
layer.input_scale = None
replace_parameter(layer, "weight", weight_fp8.data)
replace_parameter(layer, "weight_scale", weight_scale.data)
layer._already_called_process_weights_after_loading = True
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return self.mxfp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
bias=bias,
)
class Mxfp8OnlineMoEMethod(Fp8OnlineMoEMethod):
"""MoE method for online MXFP8 (block) quantization."""
uses_meta_device: bool = True
def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
FusedMoEMethodBase.__init__(self, layer.moe_config)
self.quant_config = quant_config
assert not quant_config.is_checkpoint_fp8_serialized
assert quant_config.activation_scheme == "dynamic"
self.weight_block_size = [1, MXFP8_BLOCK_SIZE]
self.block_quant = True
self.weight_scale_name = "weight_scale"
self.fp8_backend, self.experts_cls = select_mxfp8_moe_backend(config=self.moe)
def create_weights(
self,
layer: Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
if (
hidden_size % MXFP8_BLOCK_SIZE != 0
or intermediate_size_per_partition % MXFP8_BLOCK_SIZE != 0
):
raise ValueError(
"Online MXFP8 MoE requires hidden/intermediate sizes divisible "
f"by {MXFP8_BLOCK_SIZE}."
)
super().create_weights(
layer=layer,
num_experts=num_experts,
hidden_size=hidden_size,
intermediate_size_per_partition=intermediate_size_per_partition,
params_dtype=params_dtype,
**extra_weight_attrs,
)
w13_weight_scale = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
hidden_size // MXFP8_BLOCK_SIZE,
dtype=torch.uint8,
),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.zeros(
num_experts,
hidden_size,
intermediate_size_per_partition // MXFP8_BLOCK_SIZE,
dtype=torch.uint8,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
layer.weight_block_size = [1, MXFP8_BLOCK_SIZE]
def _quantize_mxfp8_moe_weight(
self, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""Batch quantization: bf16/fp16 weights -> MXFP8 (fp8 + uint8 scales)."""
num_batches = weight.size(0)
w_quant = []
w_scales = []
for i in range(num_batches):
mx_fp8_quant, mx_fp8_scale = mxfp8_e4m3_quantize(
weight[i], is_sf_swizzled_layout=False
)
w_quant.append(mx_fp8_quant)
w_scales.append(mx_fp8_scale)
return torch.stack(w_quant), torch.stack(w_scales)
def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return
if layer.w13_weight.device == torch.device("meta"):
w13_weight = torch.nn.Parameter(
torch.empty_like(layer.w13_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(
w13_weight, {"weight_loader": layer.w13_weight.weight_loader}
)
_copy_missing_attrs(layer.w13_weight, w13_weight)
layer.register_parameter("w13_weight", w13_weight)
initialize_single_dummy_weight(layer.w13_weight)
if layer.w2_weight.device == torch.device("meta"):
w2_weight = torch.nn.Parameter(
torch.empty_like(layer.w2_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(
w2_weight, {"weight_loader": layer.w2_weight.weight_loader}
)
_copy_missing_attrs(layer.w2_weight, w2_weight)
layer.register_parameter("w2_weight", w2_weight)
initialize_single_dummy_weight(layer.w2_weight)
fp8_dtype = current_platform.fp8_dtype()
w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
w13_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
w13, w13_scale = self._quantize_mxfp8_moe_weight(layer.w13_weight)
w2, w2_scale = self._quantize_mxfp8_moe_weight(layer.w2_weight)
self._setup_kernel(
layer,
w13,
w2,
w13_scale,
w2_scale,
layer.w13_input_scale,
layer.w2_input_scale,
)
layer._already_called_process_weights_after_loading = True

View File

@@ -305,6 +305,81 @@ def align_fp8_moe_weights_for_fi(
return padded_w13, padded_w2, padded_intermediate
def _shuffle_mxfp8_moe_weights(
w13: torch.Tensor,
w2: torch.Tensor,
w13_scale: torch.Tensor,
w2_scale: torch.Tensor,
is_gated: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Preprocess MXFP8 weights and scales for the FlashInfer TRT-LLM kernel.
Following flashinfer/tests/moe/test_trtllm_gen_fused_moe.py:
1. reorder_rows_for_gated_act_gemm (interleave gate/up rows)
2. shuffle_matrix_a (weight data layout shuffle)
3. shuffle_matrix_sf_a (scale factor layout shuffle)
"""
from flashinfer import (
reorder_rows_for_gated_act_gemm,
shuffle_matrix_a,
shuffle_matrix_sf_a,
)
epilogue_tile_m = 128
num_experts = w13.shape[0]
intermediate_size = w13.shape[1] // 2
hidden_size = w13.shape[2]
w13_interleaved: list[torch.Tensor] = []
w13_scale_interleaved: list[torch.Tensor] = []
for i in range(num_experts):
if is_gated:
w13_interleaved.append(
reorder_rows_for_gated_act_gemm(
w13[i].reshape(2 * intermediate_size, -1)
)
)
w13_scale_interleaved.append(
reorder_rows_for_gated_act_gemm(
w13_scale[i].reshape(2 * intermediate_size, -1)
)
)
else:
w13_interleaved.append(w13[i])
w13_scale_interleaved.append(w13_scale[i])
w13_shuffled: list[torch.Tensor] = []
w2_shuffled: list[torch.Tensor] = []
w13_scale_shuffled: list[torch.Tensor] = []
w2_scale_shuffled: list[torch.Tensor] = []
for i in range(num_experts):
w13_shuffled.append(
shuffle_matrix_a(w13_interleaved[i].view(torch.uint8), epilogue_tile_m)
)
w2_shuffled.append(shuffle_matrix_a(w2[i].view(torch.uint8), epilogue_tile_m))
w13_scale_shuffled.append(
shuffle_matrix_sf_a(
w13_scale_interleaved[i]
.view(torch.uint8)
.reshape(2 * intermediate_size, -1),
epilogue_tile_m,
)
)
w2_scale_shuffled.append(
shuffle_matrix_sf_a(
w2_scale[i].view(torch.uint8).reshape(hidden_size, -1),
epilogue_tile_m,
)
)
w13_out = torch.stack(w13_shuffled).view(torch.float8_e4m3fn)
w2_out = torch.stack(w2_shuffled).view(torch.float8_e4m3fn)
w13_scale_out = torch.stack(w13_scale_shuffled).reshape(w13_scale.shape)
w2_scale_out = torch.stack(w2_scale_shuffled).reshape(w2_scale.shape)
return w13_out, w2_out, w13_scale_out, w2_scale_out
def prepare_fp8_moe_layer_for_fi(
layer: torch.nn.Module,
w13: torch.Tensor,
@@ -314,7 +389,7 @@ def prepare_fp8_moe_layer_for_fi(
w2_scale: torch.Tensor,
w2_input_scale: torch.Tensor | None,
is_trtllm: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Convert Fp8 MoE weights to flashinfer kernel format
@@ -329,10 +404,33 @@ def prepare_fp8_moe_layer_for_fi(
block_quant = (
hasattr(layer, "weight_block_size") and layer.weight_block_size is not None
)
is_mxfp8 = block_quant and w13_scale.dtype == torch.uint8
is_gated = layer.activation.is_gated
# MXFP8 TRT-LLM requires W31 swap + reorder + shuffle.
if is_mxfp8 and is_trtllm:
# FlashInfer TRT-LLM SwiGLU expects [up; gate] but vLLM stores
# [gate; up]. Swap both weights and scales before interleaving.
if layer.moe_config.is_act_and_mul:
w13 = swap_w13_to_w31(w13)
# Scales may be 2D [E, flat] from _quantize_mxfp8_moe_weight;
# reshape to 3D so swap_w13_to_w31 can flip the two halves,
# then flatten back.
if w13_scale.ndim == 2:
num_rows = w13.shape[1] # 2 * intermediate_size
w13_scale = w13_scale.reshape(w13_scale.shape[0], num_rows, -1)
w13_scale = swap_w13_to_w31(w13_scale)
w13_scale = w13_scale.reshape(w13_scale.shape[0], -1)
else:
w13_scale = swap_w13_to_w31(w13_scale)
w13, w2, w13_scale, w2_scale = _shuffle_mxfp8_moe_weights(
w13, w2, w13_scale, w2_scale, is_gated
)
return w13, w2, w13_scale, w2_scale
# Some FI MoE kernels require internal alignment of 16
# for the gate-up proj. Pad the weights to respect this.
is_gated = layer.activation.is_gated
if not block_quant:
min_alignment = 16 if is_gated else 128
w13, w2, new_intermediate = align_fp8_moe_weights_for_fi(
@@ -369,4 +467,4 @@ def prepare_fp8_moe_layer_for_fi(
w13_scale.clamp_(min=_FI_CUTLASS_MIN_BLOCK_SCALE)
w2_scale.clamp_(min=_FI_CUTLASS_MIN_BLOCK_SCALE)
return w13, w2, w13_scale
return w13, w2, w13_scale, w2_scale

View File

@@ -149,6 +149,12 @@ kFp8Dynamic128Sym = QuantKey(FP8_DTYPE, kDynamic128Scale, symmetric=True)
kStatic128BlockScale = ScaleDesc(torch.float32, True, GroupShape(128, 128))
kFp8Static128BlockSym = QuantKey(FP8_DTYPE, kStatic128BlockScale, symmetric=True)
kMxfp8StaticScale = ScaleDesc(torch.uint8, True, GroupShape(1, 32))
kMxfp8Static = QuantKey(FP8_DTYPE, kMxfp8StaticScale, symmetric=True)
kMxfp8DynamicScale = ScaleDesc(torch.uint8, False, GroupShape(1, 32))
kMxfp8Dynamic = QuantKey(FP8_DTYPE, kMxfp8DynamicScale, symmetric=True)
kDynamic64Scale = ScaleDesc(torch.float32, False, GroupShape(1, 64))
kFp8Dynamic64Sym = QuantKey(FP8_DTYPE, kDynamic64Scale, symmetric=True)