diff --git a/tests/models/quantization/test_mxfp8.py b/tests/models/quantization/test_mxfp8.py new file mode 100644 index 000000000..2cb0f2008 --- /dev/null +++ b/tests/models/quantization/test_mxfp8.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""E2E tests for online MXFP8 quantization. + +Loads a BF16 model with ``--quantization mxfp8`` (online quantization) and +compares log-probabilities against the same model served in BF16 without +quantization. This exercises the full pipeline: config parsing, +``Mxfp8OnlineLinearMethod``, ``Mxfp8OnlineMoEMethod``, weight loading, +online quantization / shuffling, and inference through ``apply_monolithic``. + +Layer skipping (``modules_to_not_convert``) is configured in the model's +``config.json`` under ``quantization_config`` and is not tested here. + +``example_prompts`` is a pytest fixture (from conftest.py) that loads 8 +diverse prompts from ``tests/prompts/example.txt``. +""" + +import pytest + +from tests.quantization.utils import is_quant_method_supported + +from ..utils import check_logprobs_close + +# A small MoE model that fits on a single GPU and has both linear + MoE layers. +MOE_MODEL = "Qwen/Qwen3-30B-A3B" +# A small dense model (no MoE) to validate the linear-only path. +DENSE_MODEL = "Qwen/Qwen3-0.6B" + +MAX_MODEL_LEN = 1024 +MAX_TOKENS = 4 +NUM_LOG_PROBS = 8 + + +@pytest.mark.skipif( + not is_quant_method_supported("mxfp8"), + reason="mxfp8 is not supported on this GPU type (requires sm_100+).", +) +@pytest.mark.quant_model +@pytest.mark.parametrize("model", [DENSE_MODEL, MOE_MODEL], ids=["dense", "moe"]) +def test_mxfp8_logprobs( + vllm_runner, + example_prompts, + model: str, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Compare BF16 baseline logprobs against online MXFP8-quantized model. + + Runs the same model twice -- once in BF16 (baseline) and once with + online MXFP8 quantization -- then checks that the top log-probabilities + are close. Only 4 tokens are generated to keep the test fast while + still catching numerical divergence. + """ + with monkeypatch.context() as m: + m.setenv("TOKENIZERS_PARALLELISM", "true") + + with vllm_runner( + model, + max_model_len=MAX_MODEL_LEN, + enforce_eager=True, + ) as vllm_model: + baseline_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, MAX_TOKENS, NUM_LOG_PROBS + ) + + with vllm_runner( + model, + max_model_len=MAX_MODEL_LEN, + enforce_eager=True, + quantization="mxfp8", + ) as vllm_model: + test_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, MAX_TOKENS, NUM_LOG_PROBS + ) + + check_logprobs_close( + outputs_0_lst=baseline_outputs, + outputs_1_lst=test_outputs, + name_0="bf16", + name_1="mxfp8", + ) + + +@pytest.mark.skipif( + not is_quant_method_supported("mxfp8"), + reason="mxfp8 is not supported on this GPU type (requires sm_100+).", +) +@pytest.mark.quant_model +@pytest.mark.parametrize("model", [DENSE_MODEL, MOE_MODEL], ids=["dense", "moe"]) +def test_mxfp8_generation(vllm_runner, model: str) -> None: + """Smoke test: verify online MXFP8 model generates coherent text.""" + prompt = "1 2 3 4 5" + with vllm_runner( + model, + enforce_eager=True, + quantization="mxfp8", + max_model_len=MAX_MODEL_LEN, + ) as vllm_model: + output = vllm_model.generate_greedy([prompt], max_tokens=5) + + generated = output[0][1] + assert len(generated) > len(prompt), ( + f"MXFP8 model produced no new tokens. Output: {generated!r}" + ) diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py index 1c86702e9..74096ef6e 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py @@ -23,6 +23,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8Dynamic128Sym, kFp8Static128BlockSym, kFp8StaticTensorSym, + kMxfp8Dynamic, + kMxfp8Static, ) from vllm.platforms import current_platform @@ -67,11 +69,54 @@ class TrtLlmFp8ExpertsBase: """Does not support non-gated MoE (i.e. Nanotron-3-Nano).""" return True + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + """Supports Fp8 per-tensor, Fp8 block, and MXFP8.""" + SUPPORTED_W_A = [ + (kFp8Static128BlockSym, kFp8Dynamic128Sym), + (kFp8StaticTensorSym, kFp8StaticTensorSym), + (kMxfp8Static, kMxfp8Dynamic), + ] + return (weight_key, activation_key) in SUPPORTED_W_A + @staticmethod def _supports_activation(activation: MoEActivation) -> bool: """Supports only SiLU and RELU^2 non-gated activation.""" return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] + @staticmethod + def _supports_routing_method( + routing_method: RoutingMethodType, + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + """Monolithic kernels need to express router support.""" + # NOTE(dbari): TopK routing could also be enabled, but need to validate models + # NOTE(dbari): Default is not implemented and should not be enabled until it is + if (weight_key, activation_key) in [ + (kFp8Static128BlockSym, kFp8Dynamic128Sym), + (kMxfp8Static, kMxfp8Dynamic), + ]: + # NOTE(rob): potentially allow others here. This is a conservative list. + return routing_method in [ + RoutingMethodType.DeepSeekV3, + RoutingMethodType.Renormalize, + RoutingMethodType.RenormalizeNaive, + ] + elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym): + # NOTE(dbari): as above, potentially allow others here. + return routing_method in [ + RoutingMethodType.DeepSeekV3, + RoutingMethodType.Llama4, + RoutingMethodType.Renormalize, + RoutingMethodType.RenormalizeNaive, + ] + else: + raise ValueError("Unsupported quantization scheme.") + @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: """Monolithic kernel so only use with naive DP/EP and TP.""" @@ -113,9 +158,10 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular): weight_key: QuantKey | None, activation_key: QuantKey | None, ) -> bool: - """Supports Fp8 block.""" + """Supports Fp8 block and MXFP8.""" SUPPORTED_W_A = [ (kFp8Static128BlockSym, kFp8Dynamic128Sym), + (kMxfp8Static, kMxfp8Dynamic), ] return (weight_key, activation_key) in SUPPORTED_W_A @@ -159,6 +205,7 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular): apply_router_weight_on_input: bool, ): import flashinfer + from flashinfer.fused_moe import Fp8QuantizationType # Pack topk_ids and topk_weights into single tensor # Format: (expert_id << 16) | (weight_bf16.view(int16)) @@ -175,6 +222,16 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular): assert a1q_scale is not None + is_mxfp8 = self.quant_config.block_shape == [1, 32] + if is_mxfp8: + fp8_quant_type = Fp8QuantizationType.MxFp8 + use_shuffled_weight = True + hidden_states_scale = a1q_scale + else: + fp8_quant_type = Fp8QuantizationType.DeepSeekFp8 + use_shuffled_weight = False + hidden_states_scale = a1q_scale.t().contiguous() + # `trtllm_fp8_block_scale_routed_moe` has a bug and does not write to the # output tensor in-place so we need to manually copy the result to the # output tensor @@ -183,7 +240,7 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular): topk_ids=packed_topk_ids, routing_bias=None, hidden_states=hidden_states, - hidden_states_scale=a1q_scale.t().contiguous(), # type: ignore[union-attr] + hidden_states_scale=hidden_states_scale, gemm1_weights=w1, gemm1_weights_scale=self.quant_config.w1_scale, gemm2_weights=w2, @@ -197,8 +254,9 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular): local_num_experts=self.local_num_experts, routed_scaling_factor=None, routing_method_type=1, - use_shuffled_weight=False, + use_shuffled_weight=use_shuffled_weight, weight_layout=0, + fp8_quantization_type=fp8_quant_type, # output=output, ) output.copy_(result) @@ -240,10 +298,11 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit weight_key: QuantKey | None, activation_key: QuantKey | None, ) -> bool: - """Supports Fp8 per-tensor and Fp8 block.""" + """Supports Fp8 per-tensor, Fp8 block, and MXFP8.""" SUPPORTED_W_A = [ (kFp8Static128BlockSym, kFp8Dynamic128Sym), (kFp8StaticTensorSym, kFp8StaticTensorSym), + (kMxfp8Static, kMxfp8Dynamic), ] return (weight_key, activation_key) in SUPPORTED_W_A @@ -256,7 +315,10 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit """Monolithic kernels need to express router support.""" # NOTE(dbari): TopK routing could also be enabled, but need to validate models # NOTE(dbari): Default is not implemented and should not be enabled until it is - if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym): + if (weight_key, activation_key) in [ + (kFp8Static128BlockSym, kFp8Dynamic128Sym), + (kMxfp8Static, kMxfp8Dynamic), + ]: # NOTE(rob): potentially allow others here. This is a conservative list. return routing_method in [ RoutingMethodType.DeepSeekV3, @@ -274,7 +336,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit else: raise ValueError("Unsupported quantization scheme.") - def _apply_per_block( + def _apply_block_scale( self, hidden_states: torch.Tensor, w1: torch.Tensor, @@ -291,32 +353,38 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit routed_scaling_factor: float | None = None, topk_group: int | None = None, ) -> torch.Tensor: - # Delay import for non-CUDA. import flashinfer + from flashinfer.fused_moe import Fp8QuantizationType assert not apply_router_weight_on_input assert activation == MoEActivation.SILU + assert self.topk <= global_num_experts + assert self.topk <= 10 + assert global_num_experts % 4 == 0 + assert self.quant_config.block_shape in [[128, 128], [1, 32]] + # Kernel expects #experts <= #threads 512 + assert global_num_experts <= 512 + # TODO: fuse into the quant kernel. + assert a1q_scale is not None if self.routing_method_type == RoutingMethodType.DeepSeekV3: router_logits = router_logits.to(torch.float32) - assert self.topk <= global_num_experts - assert self.topk <= 10 - assert global_num_experts % 4 == 0 - assert self.quant_config.block_shape == [128, 128] - # Routing kernel expects #experts <= #threads 512 - assert global_num_experts <= 512 - - # Kernel requires transposed hidden state scales - # TODO: fuse into the quant kernel. - assert a1q_scale is not None - a1q_scale_t = a1q_scale.t().contiguous() + is_mxfp8 = self.quant_config.block_shape == [1, 32] + if is_mxfp8: + fp8_quant_type = Fp8QuantizationType.MxFp8 + use_shuffled_weight = True + hidden_states_scale = a1q_scale + else: + fp8_quant_type = Fp8QuantizationType.DeepSeekFp8 + use_shuffled_weight = False + hidden_states_scale = a1q_scale.t().contiguous() return flashinfer.fused_moe.trtllm_fp8_block_scale_moe( routing_logits=router_logits, routing_bias=e_score_correction_bias, hidden_states=hidden_states, - hidden_states_scale=a1q_scale_t, + hidden_states_scale=hidden_states_scale, gemm1_weights=w1, gemm1_weights_scale=self.quant_config.w1_scale, gemm2_weights=w2, @@ -330,7 +398,8 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit local_num_experts=self.local_num_experts, routed_scaling_factor=routed_scaling_factor, routing_method_type=self.routing_method_type, - use_shuffled_weight=False, + use_shuffled_weight=use_shuffled_weight, + fp8_quantization_type=fp8_quant_type, ) def _apply_per_tensor( @@ -409,7 +478,7 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit topk_group: int | None = None, ) -> torch.Tensor: if self.quant_config.block_shape is not None: - return self._apply_per_block( + return self._apply_block_scale( hidden_states, w1, w2, @@ -441,6 +510,6 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit ) else: raise NotImplementedError( - "Only per-block and per-tensor quantization are supported in " - f"{self.__class__.__name__}." + "Only per-block, per-tensor, and MXFP8 quantization are " + f"supported in {self.__class__.__name__}." ) diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index 48ca03f66..a63c02663 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -444,7 +444,7 @@ def convert_to_fp8_moe_kernel_format( Fp8MoeBackend.FLASHINFER_CUTLASS, Fp8MoeBackend.FLASHINFER_TRTLLM, ]: - w13, w2, w13_scale = prepare_fp8_moe_layer_for_fi( + w13, w2, w13_scale, w2_scale = prepare_fp8_moe_layer_for_fi( layer=layer, w13=w13, w2=w2, @@ -512,6 +512,21 @@ def make_fp8_moe_quant_config( g1_alphas=(w1_scale * a1_scale).squeeze(), g2_alphas=(w2_scale * a2_scale).squeeze(), ) + # MXFP8 uses "mxfp8" quant_dtype so the prepare step dispatches to + # _mxfp8_e4m3_quantize rather than standard FP8 block quantization. + # Non-swizzled layout is required since the TRTLLM kernel expects + # scales in (num_tokens, hidden_dim // 32) format. + if block_shape == [1, 32]: + return FusedMoEQuantConfig.make( + "mxfp8", + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_shape, + is_nvfp4_scale_swizzled=False, + ) + # All other backends use normal config. return fp8_w8a8_moe_quant_config( w1_scale=w1_scale, diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py index 49406ba93..ed3af4b5a 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py @@ -1,44 +1,87 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from enum import Enum +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( + Fp8MoeBackend, + backend_to_kernel_cls, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kMxfp8Dynamic, + kMxfp8Static, +) logger = init_logger(__name__) +_SUPPORTED_BACKENDS: frozenset[Fp8MoeBackend] = frozenset( + { + Fp8MoeBackend.FLASHINFER_TRTLLM, + } +) -class MxFp8MoeBackend(Enum): - FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM" +_BACKEND_NAME_MAP: dict[str, Fp8MoeBackend] = { + "flashinfer_trtllm": Fp8MoeBackend.FLASHINFER_TRTLLM, +} + + +def _select_kernel_cls( + backend: Fp8MoeBackend, + config: FusedMoEConfig, +) -> type[mk.FusedMoEExperts]: + """Select the first supported expert class for the MXFP8 config.""" + activation_format = ( + mk.FusedMoEActivationFormat.BatchedExperts + if config.moe_parallel_config.use_batched_activation_format + else mk.FusedMoEActivationFormat.Standard + ) + last_reason: str | None = None + for cls in backend_to_kernel_cls(backend): + supported, reason = cls.is_supported_config( + cls, + config, + kMxfp8Static, + kMxfp8Dynamic, + activation_format, + ) + if supported: + return cls + last_reason = reason + raise ValueError( + f"No supported MXFP8 expert class for {backend.value}: {last_reason}" + ) def select_mxfp8_moe_backend( config: FusedMoEConfig, -) -> MxFp8MoeBackend: +) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts]]: + """Select the MXFP8 MoE backend and the best expert class. + + Returns: + A tuple of (fp8_backend, experts_cls). + """ if config.is_lora_enabled: raise NotImplementedError("LoRA is not supported for MXFP8 MoE.") - AVAILABLE_BACKENDS = [ - MxFp8MoeBackend.FLASHINFER_TRTLLM, - ] - runner_backend = config.moe_backend if runner_backend != "auto": - mapping = { - "flashinfer_trtllm": MxFp8MoeBackend.FLASHINFER_TRTLLM, - } - if backend := mapping.get(runner_backend): - logger.info_once( - "Using '%s' MxFp8 MoE backend (user-requested).", - backend.value, + backend = _BACKEND_NAME_MAP.get(runner_backend) + if backend is None: + raise ValueError( + f"moe_backend='{runner_backend}' is not supported for " + f"MXFP8 MoE. Expected one of " + f"{list(_BACKEND_NAME_MAP.keys())}." ) - return backend - raise ValueError( - f"moe_backend='{runner_backend}' is not supported for MXFP8 MoE. " - f"Expected one of {list(mapping.keys())}." + logger.info_once( + "Using '%s' MxFp8 MoE backend (user-requested).", + backend.value, ) + return backend, _select_kernel_cls(backend, config) - # Auto-select: only one backend available for now. - backend = AVAILABLE_BACKENDS[0] - logger.info_once("Using '%s' MxFp8 MoE backend.", backend.value) - return backend + # Auto-select: pick the first supported backend. + for backend in _SUPPORTED_BACKENDS: + logger.info_once("Using '%s' MxFp8 MoE backend.", backend.value) + return backend, _select_kernel_cls(backend, config) + + raise ValueError("No MXFP8 MoE backends available.") diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 019e408c1..4adb7f1cf 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -199,7 +199,7 @@ def _mxfp8_e4m3_quantize( ) -> tuple[torch.Tensor, torch.Tensor]: assert A_scale is None assert not per_act_token_quant - assert block_shape is None + assert block_shape is None or block_shape == [1, 32] return mxfp8_e4m3_quantize(A, is_sf_swizzled_layout) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 2fb54e775..e08a6456a 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -31,6 +31,7 @@ QuantizationMethods = Literal[ "torchao", "inc", "mxfp4", + "mxfp8", "petit_nvfp4", "cpu_awq", ] @@ -129,6 +130,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ) from .moe_wna16 import MoeWNA16Config from .mxfp4 import Mxfp4Config + from .mxfp8 import Mxfp8Config from .petit import PetitNvFp4Config from .ptpc_fp8 import PTPCFp8Config from .torchao import TorchAOConfig @@ -156,6 +158,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "auto-round": INCConfig, "inc": INCConfig, "mxfp4": Mxfp4Config, + "mxfp8": Mxfp8Config, "petit_nvfp4": PetitNvFp4Config, "cpu_awq": CPUAWQConfig, } diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 640580da6..78644f74d 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -25,13 +25,13 @@ from vllm.model_executor.layers.fused_moe.layer import ( FusedMoeWeightScaleSupported, ) from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( + Fp8MoeBackend, convert_to_fp8_moe_kernel_format, make_fp8_moe_kernel, make_fp8_moe_quant_config, select_fp8_moe_backend, ) from vllm.model_executor.layers.fused_moe.oracle.mxfp8 import ( - MxFp8MoeBackend, select_mxfp8_moe_backend, ) from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( @@ -1712,8 +1712,7 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): self.quant_config = quant_config assert self.quant_config.is_checkpoint_mxfp8_serialized - # Select MXFP8 MoE backend - self.mxfp8_backend = select_mxfp8_moe_backend(self.moe) + self.mxfp8_backend, _ = select_mxfp8_moe_backend(self.moe) def create_weights( self, @@ -1943,7 +1942,7 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): @property def is_monolithic(self) -> bool: - return self.mxfp8_backend == MxFp8MoeBackend.FLASHINFER_TRTLLM + return self.mxfp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM def apply_monolithic( self, @@ -1956,7 +1955,7 @@ class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): Fp8QuantizationType, ) - assert self.mxfp8_backend == MxFp8MoeBackend.FLASHINFER_TRTLLM + assert self.mxfp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM if layer.enable_eplb: raise NotImplementedError( diff --git a/vllm/model_executor/layers/quantization/mxfp8.py b/vllm/model_executor/layers/quantization/mxfp8.py new file mode 100644 index 000000000..5b4564bea --- /dev/null +++ b/vllm/model_executor/layers/quantization/mxfp8.py @@ -0,0 +1,354 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Online MXFP8 (microscaling FP8, block-32) quantization config and methods.""" + +from typing import Any + +import torch +from torch.nn import Module + +from vllm.logger import init_logger +from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.fused_moe import ( + FusedMoE, + FusedMoEMethodBase, +) +from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod +from vllm.model_executor.layers.fused_moe.oracle.mxfp8 import ( + select_mxfp8_moe_backend, +) +from vllm.model_executor.layers.linear import ( + LinearBase, + UnquantizedLinearMethod, +) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizeMethodBase, +) +from vllm.model_executor.layers.quantization.fp8 import ( + Fp8Config, + Fp8KVCacheMethod, + Fp8OnlineLinearMethod, + Fp8OnlineMoEMethod, + _copy_missing_attrs, +) +from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( + MXFP8_BLOCK_SIZE, + Mxfp8LinearBackend, + Mxfp8LinearOp, + mxfp8_e4m3_quantize, + swizzle_mxfp8_scale, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped, +) +from vllm.model_executor.model_loader.weight_utils import ( + initialize_single_dummy_weight, +) +from vllm.model_executor.parameter import ModelWeightParameter +from vllm.model_executor.utils import replace_parameter, set_weight_attrs +from vllm.platforms import current_platform + +logger = init_logger(__name__) + + +class Mxfp8Config(Fp8Config): + """Config class for online MXFP8 MoE quantization.""" + + def __init__( + self, + activation_scheme: str = "dynamic", + ignored_layers: list[str] | None = None, + ) -> None: + if activation_scheme != "dynamic": + raise ValueError("mxfp8 only supports dynamic activation scheme.") + super().__init__( + is_checkpoint_fp8_serialized=False, + activation_scheme=activation_scheme, + ignored_layers=ignored_layers, + weight_block_size=None, + ) + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "mxfp8" + + @classmethod + def get_min_capability(cls) -> int: + return 100 + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "Mxfp8Config": + activation_scheme = cls.get_from_keys_or( + config, ["activation_scheme"], "dynamic" + ) + ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) + if not ignored_layers: + ignored_layers = cls.get_from_keys_or( + config, ["modules_to_not_convert"], None + ) + return cls( + activation_scheme=activation_scheme, + ignored_layers=ignored_layers, + ) + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> "QuantizeMethodBase | None": + if isinstance(layer, LinearBase): + if is_layer_skipped( + prefix=prefix, + ignored_layers=self.ignored_layers, + fused_mapping=self.packed_modules_mapping, + skip_with_substr=True, + ): + return UnquantizedLinearMethod() + return Mxfp8OnlineLinearMethod(self) + elif isinstance(layer, FusedMoE): + if is_layer_skipped( + prefix=prefix, + ignored_layers=self.ignored_layers, + fused_mapping=self.packed_modules_mapping, + skip_with_substr=True, + ): + return UnquantizedFusedMoEMethod(layer.moe_config) + return Mxfp8OnlineMoEMethod(self, layer) + elif isinstance(layer, Attention): + return Fp8KVCacheMethod(self) + return None + + +class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod): + """Online MXFP8 linear method. + Loads bf16/fp16 checkpoints and quantizes weights to MXFP8 (microscaling + FP8 with block-32 scales) during weight loading. + + Args: + quant_config: The MXFP8 quantization config. + """ + + uses_meta_device: bool = True + + def __init__(self, quant_config: "Mxfp8Config"): + self.quant_config = quant_config + self.out_dtype = torch.get_default_dtype() + self.mxfp8_linear = Mxfp8LinearOp(self._select_backend()) + logger.info_once( + "Using %s backend for MXFP8 GEMM", self.mxfp8_linear.backend.value + ) + + @staticmethod + def _select_backend() -> Mxfp8LinearBackend: + try: + from vllm.utils import flashinfer as fi + + _ = fi.mm_mxfp8 + return Mxfp8LinearBackend.FLASHINFER_CUTLASS + except Exception: + logger.warning( + "FlashInfer mm_mxfp8 not available, " + "falling back to MXFP8 emulation backend." + ) + return Mxfp8LinearBackend.EMULATION + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + if input_size_per_partition % MXFP8_BLOCK_SIZE != 0: + raise ValueError( + f"MXFP8 requires input_size_per_partition " + f"({input_size_per_partition}) to be divisible by " + f"{MXFP8_BLOCK_SIZE}." + ) + + super().create_weights( + layer, + input_size_per_partition, + output_partition_sizes, + input_size, + output_size, + params_dtype, + **extra_weight_attrs, + ) + + def process_weights_after_loading(self, layer: Module) -> None: + if getattr(layer, "_already_called_process_weights_after_loading", False): + return + + if layer.weight.device == torch.device("meta"): + weight = ModelWeightParameter( + data=torch.empty_like(layer.weight, device=layer._load_device), + input_dim=1, + output_dim=0, + weight_loader=layer.weight.weight_loader, + ) + _copy_missing_attrs(layer.weight, weight) + layer.register_parameter("weight", weight) + initialize_single_dummy_weight(layer.weight) + + weight_fp8, weight_scale = mxfp8_e4m3_quantize(layer.weight.contiguous()) + + if self.mxfp8_linear.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS: + N, K = layer.weight.shape[0], layer.weight.shape[1] + weight_scale = swizzle_mxfp8_scale(weight_scale, N, K) + + layer.input_scale = None + replace_parameter(layer, "weight", weight_fp8.data) + replace_parameter(layer, "weight_scale", weight_scale.data) + + layer._already_called_process_weights_after_loading = True + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + return self.mxfp8_linear.apply( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + out_dtype=self.out_dtype, + bias=bias, + ) + + +class Mxfp8OnlineMoEMethod(Fp8OnlineMoEMethod): + """MoE method for online MXFP8 (block) quantization.""" + + uses_meta_device: bool = True + + def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): + FusedMoEMethodBase.__init__(self, layer.moe_config) + self.quant_config = quant_config + assert not quant_config.is_checkpoint_fp8_serialized + assert quant_config.activation_scheme == "dynamic" + + self.weight_block_size = [1, MXFP8_BLOCK_SIZE] + self.block_quant = True + self.weight_scale_name = "weight_scale" + + self.fp8_backend, self.experts_cls = select_mxfp8_moe_backend(config=self.moe) + + def create_weights( + self, + layer: Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + if ( + hidden_size % MXFP8_BLOCK_SIZE != 0 + or intermediate_size_per_partition % MXFP8_BLOCK_SIZE != 0 + ): + raise ValueError( + "Online MXFP8 MoE requires hidden/intermediate sizes divisible " + f"by {MXFP8_BLOCK_SIZE}." + ) + + super().create_weights( + layer=layer, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size_per_partition=intermediate_size_per_partition, + params_dtype=params_dtype, + **extra_weight_attrs, + ) + + w13_weight_scale = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // MXFP8_BLOCK_SIZE, + dtype=torch.uint8, + ), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + intermediate_size_per_partition // MXFP8_BLOCK_SIZE, + dtype=torch.uint8, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + layer.weight_block_size = [1, MXFP8_BLOCK_SIZE] + + def _quantize_mxfp8_moe_weight( + self, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """Batch quantization: bf16/fp16 weights -> MXFP8 (fp8 + uint8 scales).""" + num_batches = weight.size(0) + w_quant = [] + w_scales = [] + for i in range(num_batches): + mx_fp8_quant, mx_fp8_scale = mxfp8_e4m3_quantize( + weight[i], is_sf_swizzled_layout=False + ) + w_quant.append(mx_fp8_quant) + w_scales.append(mx_fp8_scale) + + return torch.stack(w_quant), torch.stack(w_scales) + + def process_weights_after_loading(self, layer: Module) -> None: + if getattr(layer, "_already_called_process_weights_after_loading", False): + return + + if layer.w13_weight.device == torch.device("meta"): + w13_weight = torch.nn.Parameter( + torch.empty_like(layer.w13_weight, device=layer._load_device), + requires_grad=False, + ) + set_weight_attrs( + w13_weight, {"weight_loader": layer.w13_weight.weight_loader} + ) + _copy_missing_attrs(layer.w13_weight, w13_weight) + layer.register_parameter("w13_weight", w13_weight) + initialize_single_dummy_weight(layer.w13_weight) + if layer.w2_weight.device == torch.device("meta"): + w2_weight = torch.nn.Parameter( + torch.empty_like(layer.w2_weight, device=layer._load_device), + requires_grad=False, + ) + set_weight_attrs( + w2_weight, {"weight_loader": layer.w2_weight.weight_loader} + ) + _copy_missing_attrs(layer.w2_weight, w2_weight) + layer.register_parameter("w2_weight", w2_weight) + initialize_single_dummy_weight(layer.w2_weight) + + fp8_dtype = current_platform.fp8_dtype() + w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype) + w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype) + w13_scale = layer.w13_weight_scale + w2_scale = layer.w2_weight_scale + + w13, w13_scale = self._quantize_mxfp8_moe_weight(layer.w13_weight) + w2, w2_scale = self._quantize_mxfp8_moe_weight(layer.w2_weight) + + self._setup_kernel( + layer, + w13, + w2, + w13_scale, + w2_scale, + layer.w13_input_scale, + layer.w2_input_scale, + ) + + layer._already_called_process_weights_after_loading = True diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 322b3a6e8..271bcf168 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -305,6 +305,81 @@ def align_fp8_moe_weights_for_fi( return padded_w13, padded_w2, padded_intermediate +def _shuffle_mxfp8_moe_weights( + w13: torch.Tensor, + w2: torch.Tensor, + w13_scale: torch.Tensor, + w2_scale: torch.Tensor, + is_gated: bool, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Preprocess MXFP8 weights and scales for the FlashInfer TRT-LLM kernel. + + Following flashinfer/tests/moe/test_trtllm_gen_fused_moe.py: + 1. reorder_rows_for_gated_act_gemm (interleave gate/up rows) + 2. shuffle_matrix_a (weight data layout shuffle) + 3. shuffle_matrix_sf_a (scale factor layout shuffle) + """ + from flashinfer import ( + reorder_rows_for_gated_act_gemm, + shuffle_matrix_a, + shuffle_matrix_sf_a, + ) + + epilogue_tile_m = 128 + num_experts = w13.shape[0] + intermediate_size = w13.shape[1] // 2 + hidden_size = w13.shape[2] + + w13_interleaved: list[torch.Tensor] = [] + w13_scale_interleaved: list[torch.Tensor] = [] + for i in range(num_experts): + if is_gated: + w13_interleaved.append( + reorder_rows_for_gated_act_gemm( + w13[i].reshape(2 * intermediate_size, -1) + ) + ) + w13_scale_interleaved.append( + reorder_rows_for_gated_act_gemm( + w13_scale[i].reshape(2 * intermediate_size, -1) + ) + ) + else: + w13_interleaved.append(w13[i]) + w13_scale_interleaved.append(w13_scale[i]) + + w13_shuffled: list[torch.Tensor] = [] + w2_shuffled: list[torch.Tensor] = [] + w13_scale_shuffled: list[torch.Tensor] = [] + w2_scale_shuffled: list[torch.Tensor] = [] + for i in range(num_experts): + w13_shuffled.append( + shuffle_matrix_a(w13_interleaved[i].view(torch.uint8), epilogue_tile_m) + ) + w2_shuffled.append(shuffle_matrix_a(w2[i].view(torch.uint8), epilogue_tile_m)) + w13_scale_shuffled.append( + shuffle_matrix_sf_a( + w13_scale_interleaved[i] + .view(torch.uint8) + .reshape(2 * intermediate_size, -1), + epilogue_tile_m, + ) + ) + w2_scale_shuffled.append( + shuffle_matrix_sf_a( + w2_scale[i].view(torch.uint8).reshape(hidden_size, -1), + epilogue_tile_m, + ) + ) + + w13_out = torch.stack(w13_shuffled).view(torch.float8_e4m3fn) + w2_out = torch.stack(w2_shuffled).view(torch.float8_e4m3fn) + w13_scale_out = torch.stack(w13_scale_shuffled).reshape(w13_scale.shape) + w2_scale_out = torch.stack(w2_scale_shuffled).reshape(w2_scale.shape) + + return w13_out, w2_out, w13_scale_out, w2_scale_out + + def prepare_fp8_moe_layer_for_fi( layer: torch.nn.Module, w13: torch.Tensor, @@ -314,7 +389,7 @@ def prepare_fp8_moe_layer_for_fi( w2_scale: torch.Tensor, w2_input_scale: torch.Tensor | None, is_trtllm: bool = False, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Convert Fp8 MoE weights to flashinfer kernel format @@ -329,10 +404,33 @@ def prepare_fp8_moe_layer_for_fi( block_quant = ( hasattr(layer, "weight_block_size") and layer.weight_block_size is not None ) + is_mxfp8 = block_quant and w13_scale.dtype == torch.uint8 + is_gated = layer.activation.is_gated + + # MXFP8 TRT-LLM requires W31 swap + reorder + shuffle. + if is_mxfp8 and is_trtllm: + # FlashInfer TRT-LLM SwiGLU expects [up; gate] but vLLM stores + # [gate; up]. Swap both weights and scales before interleaving. + if layer.moe_config.is_act_and_mul: + w13 = swap_w13_to_w31(w13) + # Scales may be 2D [E, flat] from _quantize_mxfp8_moe_weight; + # reshape to 3D so swap_w13_to_w31 can flip the two halves, + # then flatten back. + if w13_scale.ndim == 2: + num_rows = w13.shape[1] # 2 * intermediate_size + w13_scale = w13_scale.reshape(w13_scale.shape[0], num_rows, -1) + w13_scale = swap_w13_to_w31(w13_scale) + w13_scale = w13_scale.reshape(w13_scale.shape[0], -1) + else: + w13_scale = swap_w13_to_w31(w13_scale) + + w13, w2, w13_scale, w2_scale = _shuffle_mxfp8_moe_weights( + w13, w2, w13_scale, w2_scale, is_gated + ) + return w13, w2, w13_scale, w2_scale # Some FI MoE kernels require internal alignment of 16 # for the gate-up proj. Pad the weights to respect this. - is_gated = layer.activation.is_gated if not block_quant: min_alignment = 16 if is_gated else 128 w13, w2, new_intermediate = align_fp8_moe_weights_for_fi( @@ -369,4 +467,4 @@ def prepare_fp8_moe_layer_for_fi( w13_scale.clamp_(min=_FI_CUTLASS_MIN_BLOCK_SCALE) w2_scale.clamp_(min=_FI_CUTLASS_MIN_BLOCK_SCALE) - return w13, w2, w13_scale + return w13, w2, w13_scale, w2_scale diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 12a1799d1..1170a2d3a 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -149,6 +149,12 @@ kFp8Dynamic128Sym = QuantKey(FP8_DTYPE, kDynamic128Scale, symmetric=True) kStatic128BlockScale = ScaleDesc(torch.float32, True, GroupShape(128, 128)) kFp8Static128BlockSym = QuantKey(FP8_DTYPE, kStatic128BlockScale, symmetric=True) +kMxfp8StaticScale = ScaleDesc(torch.uint8, True, GroupShape(1, 32)) +kMxfp8Static = QuantKey(FP8_DTYPE, kMxfp8StaticScale, symmetric=True) + +kMxfp8DynamicScale = ScaleDesc(torch.uint8, False, GroupShape(1, 32)) +kMxfp8Dynamic = QuantKey(FP8_DTYPE, kMxfp8DynamicScale, symmetric=True) + kDynamic64Scale = ScaleDesc(torch.float32, False, GroupShape(1, 64)) kFp8Dynamic64Sym = QuantKey(FP8_DTYPE, kDynamic64Scale, symmetric=True)