diff --git a/tests/kernels/moe/test_ocp_mx_moe.py b/tests/kernels/moe/test_ocp_mx_moe.py index c9b2b85f0..73502932d 100644 --- a/tests/kernels/moe/test_ocp_mx_moe.py +++ b/tests/kernels/moe/test_ocp_mx_moe.py @@ -20,6 +20,8 @@ TRTLLM_GEN_MXFP4_AVAILABLE = ( current_platform.is_cuda() and current_platform.is_device_capability_family(100) ) +TRTLLM_GEN_MXFP8_AVAILABLE = TRTLLM_GEN_MXFP4_AVAILABLE + HOPPER_MXFP4_BF16_AVAILABLE = ( current_platform.is_cuda() and current_platform.is_device_capability(90) @@ -34,9 +36,15 @@ if TRTLLM_GEN_MXFP4_AVAILABLE: shuffle_matrix_a, shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe, + trtllm_fp8_block_scale_moe, ) from flashinfer.fp4_quantization import nvfp4_block_scale_interleave - from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache + +if TRTLLM_GEN_MXFP8_AVAILABLE: + from flashinfer.fused_moe.core import ( + Fp8QuantizationType, + get_w2_permute_indices_with_cache, + ) @dataclass @@ -160,6 +168,7 @@ def reference_moe( beta, limit, act_type, + is_gated, ): # renormalize routing experts = torch.topk(roouting_logits, k=topk, dim=-1, sorted=True) @@ -170,7 +179,12 @@ def reference_moe( mlp1_weight = w13[expert_indices, ...] mlp1_bias = bias13[expert_indices, ...] t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias - t = swiglu(t, alpha=alpha, beta=beta, limit=limit) + if is_gated: + t = swiglu(t, alpha=alpha, beta=beta, limit=limit) + else: + # RELU2_NO_MUL: relu(x)^2 + t = torch.relu(t) + t = t * t if act_type == "mxfp8": t_quantized, t_scale = mxfp8_quantize( @@ -569,6 +583,7 @@ def test_trtllm_gen_mxfp4_fused_moe( beta, limit, act_type, + is_gated=True, ) ref_result[start_idx:end_idx].copy_(chunk_result) @@ -705,6 +720,7 @@ def test_flashinfer_cutlass_mxfp4_fused_moe( beta, limit, "bf16", + is_gated=True, ) from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe @@ -890,6 +906,7 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe( beta, limit, "mxfp8", + is_gated=True, ) # Prepare inputs for FlashInfer CUTLASS fused MoE @@ -965,3 +982,169 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe( # Allow some mismatch due to MXFP4 quantization check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8) + + +@pytest.mark.parametrize("topk", [1, 4]) +@pytest.mark.parametrize("num_experts", [32]) +@pytest.mark.parametrize("num_tokens", [1, 128]) +@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)]) +@pytest.mark.parametrize("is_gated", [True], ids=["gated"]) +@pytest.mark.skipif( + not TRTLLM_GEN_MXFP8_AVAILABLE, + reason="nvidia gpu and compute capability sm100 is required for this test", +) +def test_trtllm_gen_mxfp8_block_scale_moe( + topk: int, + num_experts: int, + num_tokens: int, + intermediate_size: int, + hidden_size: int, + is_gated: bool, +): + torch.manual_seed(42) + device = "cuda:0" + + inter_size = intermediate_size * (2 if is_gated else 1) + + hidden_states = ( + torch.randn(num_tokens, hidden_size, device=device, dtype=torch.bfloat16) / 20 + ) + w13 = ( + torch.randn( + num_experts, + inter_size, + hidden_size, + device=device, + dtype=torch.bfloat16, + ) + / 20 + ) + w2 = ( + torch.randn( + num_experts, + hidden_size, + intermediate_size, + device=device, + dtype=torch.bfloat16, + ) + / 20 + ) + router_logits = torch.rand( + num_tokens, num_experts, dtype=torch.float32, device=device + ) + router_logits_kernel = router_logits.to(torch.bfloat16) + + # Quantize weights to MXFP8 and normalize scales to [E, M, K//32]. + w13_q, w13_scale = mxfp8_quantize(w13, is_sf_swizzled_layout=False) + w2_q, w2_scale = mxfp8_quantize(w2, is_sf_swizzled_layout=False) + if w13_scale.ndim == 1: + w13_scale = w13_scale.view( + num_experts, + inter_size, + hidden_size // 32, + ) + if w2_scale.ndim == 1: + w2_scale = w2_scale.view(num_experts, hidden_size, intermediate_size // 32) + + # Quantize activations to MXFP8. + hidden_states_q, hidden_states_scale = mxfp8_quantize( + hidden_states, is_sf_swizzled_layout=False + ) + if hidden_states_scale.ndim == 1: + hidden_states_scale = hidden_states_scale.view(num_tokens, hidden_size // 32) + + # Reference output using dequantized tensors + MXFP8 intermediate quantization. + w13_ref = mxfp8_dequantize(w13_q, w13_scale).to(torch.float32) + w2_ref = mxfp8_dequantize(w2_q, w2_scale).to(torch.float32) + hidden_states_ref = mxfp8_dequantize(hidden_states_q, hidden_states_scale).to( + torch.float32 + ) + bias13 = torch.zeros( + num_experts, + intermediate_size * (2 if is_gated else 1), + device=device, + ) + bias2 = torch.zeros(num_experts, hidden_size, device=device) + ref = reference_moe( + router_logits_kernel.to(torch.float32), + topk, + num_experts, + hidden_states_ref, + w13_ref, + bias13, + w2_ref, + bias2, + alpha=1.0, + beta=0.0, + limit=None, + act_type="mxfp8", + is_gated=is_gated, + ) + + # Shuffle weights/scales with the same indexed layout used by TRTLLM kernels. + epilogue_tile_m = 128 + gemm1_weights_shuffled = [] + gemm1_scales_shuffled = [] + gemm2_weights_shuffled = [] + gemm2_scales_shuffled = [] + for i in range(num_experts): + w13_rows = intermediate_size * (2 if is_gated else 1) + w13_interleaved = w13_q[i].clone().reshape(w13_rows, -1) + w13_scale_interleaved = w13_scale[i].clone().reshape(w13_rows, -1) + if is_gated: + w13_interleaved = reorder_rows_for_gated_act_gemm(w13_interleaved) + w13_scale_interleaved = reorder_rows_for_gated_act_gemm( + w13_scale_interleaved + ) + gemm1_weights_shuffled.append( + shuffle_matrix_a(w13_interleaved.view(torch.uint8), epilogue_tile_m) + .contiguous() + .view(w13_q.dtype) + ) + gemm2_weights_shuffled.append( + shuffle_matrix_a(w2_q[i].view(torch.uint8), epilogue_tile_m) + .contiguous() + .view(w2_q.dtype) + ) + + gemm1_scales_shuffled.append( + shuffle_matrix_sf_a( + w13_scale_interleaved.view(torch.uint8).reshape(w13_rows, -1), + epilogue_tile_m, + ) + .contiguous() + .view(w13_scale.dtype) + ) + gemm2_scales_shuffled.append( + shuffle_matrix_sf_a( + w2_scale[i].view(torch.uint8).reshape(hidden_size, -1), epilogue_tile_m + ) + .contiguous() + .view(w2_scale.dtype) + ) + + out = trtllm_fp8_block_scale_moe( + routing_logits=router_logits_kernel, + routing_bias=None, + hidden_states=hidden_states_q, + hidden_states_scale=hidden_states_scale, + gemm1_weights=torch.stack(gemm1_weights_shuffled), + gemm1_weights_scale=torch.stack(gemm1_scales_shuffled), + gemm2_weights=torch.stack(gemm2_weights_shuffled), + gemm2_weights_scale=torch.stack(gemm2_scales_shuffled), + num_experts=num_experts, + top_k=topk, + n_group=None, + topk_group=None, + intermediate_size=intermediate_size, + local_expert_offset=0, + local_num_experts=num_experts, + routed_scaling_factor=None, + routing_method_type=1, # renormalize routing + use_shuffled_weight=True, + weight_layout=0, # MajorK + fp8_quantization_type=Fp8QuantizationType.MxFp8, + ) + + # Block-scale MXFP8 kernels are approximate; require majority close. + check_accuracy(ref, out, atol=0.1, rtol=0.85, percent=0.8) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 620047709..92b0f0e0d 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1204,17 +1204,26 @@ class FusedMoE(CustomOp): # Determine per-tensor weight scale patterns based on variant # Use the dedicated method instead of brittle string matching uses_weight_scale_2 = self.quant_method.uses_weight_scale_2_pattern() + quant_method = getattr(param, "quant_method", None) # Call _load_per_tensor_weight_scale() to load per-tensor (scalar) # weights scales. # Input scales are always per-tensor. # Weight scales: FP4 uses "weight_scale_2" and FP8 uses # "weight_scale" for per-tensor scales. + # NOTE: ModelOpt MXFP8 MoE uses block scales in weight_scale + # tensors (quant_method=BLOCK), so those must not be treated + # as per-tensor scalars here. + is_block_weight_scale = ( + "weight_scale" in weight_name + and quant_method == FusedMoeWeightScaleSupported.BLOCK.value + ) is_per_tensor = ( "weight_scale_2" in weight_name if uses_weight_scale_2 else "weight_scale" in weight_name ) or "input_scale" in weight_name + is_per_tensor = is_per_tensor and not is_block_weight_scale if is_per_tensor: self._load_per_tensor_weight_scale( shard_id=shard_id, diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py new file mode 100644 index 000000000..49406ba93 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from enum import Enum + +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig + +logger = init_logger(__name__) + + +class MxFp8MoeBackend(Enum): + FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM" + + +def select_mxfp8_moe_backend( + config: FusedMoEConfig, +) -> MxFp8MoeBackend: + if config.is_lora_enabled: + raise NotImplementedError("LoRA is not supported for MXFP8 MoE.") + + AVAILABLE_BACKENDS = [ + MxFp8MoeBackend.FLASHINFER_TRTLLM, + ] + + runner_backend = config.moe_backend + if runner_backend != "auto": + mapping = { + "flashinfer_trtllm": MxFp8MoeBackend.FLASHINFER_TRTLLM, + } + if backend := mapping.get(runner_backend): + logger.info_once( + "Using '%s' MxFp8 MoE backend (user-requested).", + backend.value, + ) + return backend + raise ValueError( + f"moe_backend='{runner_backend}' is not supported for MXFP8 MoE. " + f"Expected one of {list(mapping.keys())}." + ) + + # Auto-select: only one backend available for now. + backend = AVAILABLE_BACKENDS[0] + logger.info_once("Using '%s' MxFp8 MoE backend.", backend.value) + return backend diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index f167e2134..977612313 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -9,17 +9,19 @@ from torch.nn.parameter import Parameter import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger -from vllm.model_executor.kernels.linear import ( - init_fp8_linear_kernel, -) +from vllm.model_executor.kernels.linear import init_fp8_linear_kernel from vllm.model_executor.layers.attention import Attention, MLAAttention +from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, + RoutingMethodType, +) +from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( + FusedMoEMethodBase, ) from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, - FusedMoEMethodBase, FusedMoeWeightScaleSupported, ) from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( @@ -28,6 +30,10 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( make_fp8_moe_quant_config, select_fp8_moe_backend, ) +from vllm.model_executor.layers.fused_moe.oracle.mxfp8 import ( + MxFp8MoeBackend, + select_mxfp8_moe_backend, +) from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( convert_to_nvfp4_moe_kernel_format, is_global_sf_supported_for_nvfp4_backend, @@ -46,6 +52,9 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizeMethodBase, ) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + swap_w13_to_w31, +) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, process_fp8_input_tensor_strategy_moe, @@ -60,6 +69,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( MXFP8_VALUE_DTYPE, Mxfp8LinearBackend, Mxfp8LinearOp, + mxfp8_e4m3_quantize, swizzle_mxfp8_scale, ) from vllm.model_executor.layers.quantization.utils.nvfp4_utils import ( @@ -86,7 +96,8 @@ from vllm.model_executor.parameter import ( ModelWeightParameter, PerTensorScaleParameter, ) -from vllm.model_executor.utils import replace_parameter +from vllm.model_executor.utils import replace_parameter, set_weight_attrs +from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe if TYPE_CHECKING: from vllm.model_executor.models.utils import WeightsMapper @@ -1487,17 +1498,6 @@ class ModelOptMxFp8Config(ModelOptQuantConfigBase): # MXFP8 hardware acceleration requires Blackwell (SM100) or newer return 100 - def get_quant_method( - self, layer: torch.nn.Module, prefix: str - ) -> "QuantizeMethodBase | None": - # MXFP8 does not yet support MoE models - if isinstance(layer, FusedMoE): - raise NotImplementedError( - "MXFP8 quantization does not yet support MoE models. " - "Please use FP8 or NVFP4 quantization for MoE models." - ) - return super().get_quant_method(layer, prefix) - @classmethod def override_quantization_method( cls, hf_quant_cfg, user_quant @@ -1699,8 +1699,351 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase): ) +class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): + """FlashInfer TRTLLM MXFP8 block-scale MoE for ModelOpt checkpoints.""" + + def __init__( + self, + quant_config: ModelOptMxFp8Config, + moe_config: FusedMoEConfig, + ) -> None: + super().__init__(moe_config) + self.quant_config = quant_config + assert self.quant_config.is_checkpoint_mxfp8_serialized + + # Select MXFP8 MoE backend + self.mxfp8_backend = select_mxfp8_moe_backend(self.moe) + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + layer.intermediate_size_per_partition = intermediate_size_per_partition + layer.hidden_size = hidden_size + layer.orig_dtype = params_dtype + + if hidden_size % MXFP8_BLOCK_SIZE != 0: + raise ValueError( + f"MXFP8 MoE requires hidden_size divisible by {MXFP8_BLOCK_SIZE}, " + f"got {hidden_size}." + ) + if intermediate_size_per_partition % MXFP8_BLOCK_SIZE != 0: + raise ValueError( + "MXFP8 MoE requires intermediate_size_per_partition divisible by " + f"{MXFP8_BLOCK_SIZE}, got {intermediate_size_per_partition}." + ) + + layer.num_experts = num_experts + weight_loader = extra_weight_attrs.get("weight_loader") + w13_num_shards = 2 if self.moe.is_act_and_mul else 1 + + # GEMM 1 weights: [E, (2I or I), H] + w13_weight = ModelWeightParameter( + data=torch.empty( + num_experts, + w13_num_shards * intermediate_size_per_partition, + hidden_size, + dtype=MXFP8_VALUE_DTYPE, + ), + input_dim=2, + output_dim=1, + weight_loader=weight_loader, + ) + layer.register_parameter("w13_weight", w13_weight) + + # GEMM 2 weights: [E, H, I] + w2_weight = ModelWeightParameter( + data=torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=MXFP8_VALUE_DTYPE, + ), + input_dim=2, + output_dim=1, + weight_loader=weight_loader, + ) + layer.register_parameter("w2_weight", w2_weight) + + # Per-block (K=32) E8M0 scales. + w13_weight_scale = ModelWeightParameter( + data=torch.empty( + num_experts, + w13_num_shards * intermediate_size_per_partition, + hidden_size // MXFP8_BLOCK_SIZE, + dtype=MXFP8_SCALE_DTYPE, + ), + input_dim=2, + output_dim=1, + weight_loader=weight_loader, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = ModelWeightParameter( + data=torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // MXFP8_BLOCK_SIZE, + dtype=MXFP8_SCALE_DTYPE, + ), + input_dim=2, + output_dim=1, + weight_loader=weight_loader, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + # Ensure the generic MoE weight-loader treats these as block scales. + set_weight_attrs( + layer.w13_weight_scale, + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}, + ) + set_weight_attrs( + layer.w2_weight_scale, + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}, + ) + + @staticmethod + def _check_weight_dtypes(layer: torch.nn.Module) -> None: + """Validate weight and scale dtypes before processing.""" + expected = { + "w13_weight": MXFP8_VALUE_DTYPE, + "w2_weight": MXFP8_VALUE_DTYPE, + "w13_weight_scale": MXFP8_SCALE_DTYPE, + "w2_weight_scale": MXFP8_SCALE_DTYPE, + } + for name, expected_dtype in expected.items(): + actual = getattr(layer, name).dtype + if actual != expected_dtype: + raise ValueError( + f"Expected {name} dtype {expected_dtype}, got {actual}." + ) + + def _shuffle_weights_for_trtllm(self, layer: torch.nn.Module) -> None: + """Shuffle weights and scales into FlashInfer TRTLLM MXFP8 layout.""" + from flashinfer import ( + reorder_rows_for_gated_act_gemm, + shuffle_matrix_a, + shuffle_matrix_sf_a, + ) + + epilogue_tile_m = 128 + num_experts = layer.w13_weight.shape[0] + is_gated = self.moe.is_act_and_mul + intermediate_size_factor = 2 if is_gated else 1 + + w13_weight = layer.w13_weight.data + w13_scale = layer.w13_weight_scale.data + if is_gated: + # FI TRTLLM gated kernels use W31 ordering. Model checkpoints store + # gated projection as W13, so convert once before shuffling. + w13_weight = swap_w13_to_w31(w13_weight) + w13_scale = swap_w13_to_w31(w13_scale) + + w13_weight_shuffled = [] + w2_weight_shuffled = [] + w13_scale_shuffled = [] + w2_scale_shuffled = [] + for i in range(num_experts): + w13_i = w13_weight[i].reshape( + intermediate_size_factor * layer.intermediate_size_per_partition, -1 + ) + w13_sf_i = w13_scale[i].reshape( + intermediate_size_factor * layer.intermediate_size_per_partition, -1 + ) + if is_gated: + # Reorder rows for gated activation layout expected by TRTLLM. + w13_i = reorder_rows_for_gated_act_gemm(w13_i.clone()) + w13_sf_i = reorder_rows_for_gated_act_gemm(w13_sf_i.clone()) + + w13_shuffled_i = shuffle_matrix_a(w13_i.view(torch.uint8), epilogue_tile_m) + w2_shuffled_i = shuffle_matrix_a( + layer.w2_weight.data[i].view(torch.uint8), epilogue_tile_m + ) + w13_weight_shuffled.append( + w13_shuffled_i.contiguous().view(MXFP8_VALUE_DTYPE) + ) + w2_weight_shuffled.append( + w2_shuffled_i.contiguous().view(MXFP8_VALUE_DTYPE) + ) + w13_sf_shuffled_i = shuffle_matrix_sf_a( + w13_sf_i.view(torch.uint8).reshape( + intermediate_size_factor * layer.intermediate_size_per_partition, + -1, + ), + epilogue_tile_m, + ) + w2_sf_shuffled_i = shuffle_matrix_sf_a( + layer.w2_weight_scale.data[i] + .view(torch.uint8) + .reshape(layer.hidden_size, -1), + epilogue_tile_m, + ) + w13_scale_shuffled.append( + w13_sf_shuffled_i.contiguous().view(MXFP8_SCALE_DTYPE) + ) + w2_scale_shuffled.append( + w2_sf_shuffled_i.contiguous().view(MXFP8_SCALE_DTYPE) + ) + + replace_parameter( + layer, "w13_weight", torch.stack(w13_weight_shuffled).contiguous() + ) + replace_parameter( + layer, "w2_weight", torch.stack(w2_weight_shuffled).contiguous() + ) + replace_parameter( + layer, + "w13_weight_scale", + torch.stack(w13_scale_shuffled).contiguous(), + ) + replace_parameter( + layer, + "w2_weight_scale", + torch.stack(w2_scale_shuffled).contiguous(), + ) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if getattr(layer, "_already_called_process_weights_after_loading", False): + return + + self._check_weight_dtypes(layer) + self._shuffle_weights_for_trtllm(layer) + layer._already_called_process_weights_after_loading = True + + def maybe_make_prepare_finalize( + self, + routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, + ) -> mk.FusedMoEPrepareAndFinalizeModular | None: + raise ValueError( + f"{self.__class__.__name__} uses the new modular kernel initialization " + "logic. This function should not be called." + ) + + def select_gemm_impl( + self, + prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular, + layer: torch.nn.Module, + ) -> mk.FusedMoEExpertsModular: + raise ValueError( + f"{self.__class__.__name__} uses the new modular kernel initialization " + "logic. This function should not be called." + ) + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> FusedMoEQuantConfig | None: + # TRTLLM MXFP8 path is monolithic and does not use modular kernel config. + return None + + @property + def is_monolithic(self) -> bool: + return self.mxfp8_backend == MxFp8MoeBackend.FLASHINFER_TRTLLM + + def apply_monolithic( + self, + layer: FusedMoE, + x: torch.Tensor, + router_logits: torch.Tensor, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + from flashinfer.fused_moe.core import ( + ActivationType, + Fp8QuantizationType, + ) + + assert self.mxfp8_backend == MxFp8MoeBackend.FLASHINFER_TRTLLM + + if layer.enable_eplb: + raise NotImplementedError( + "EPLB is not supported for FlashInfer TRTLLM MXFP8 MoE backend." + ) + + supported_activations = [MoEActivation.SILU] + if layer.activation not in supported_activations: + raise NotImplementedError( + "FlashInfer TRTLLM MXFP8 MoE supports only " + f"{supported_activations}, got {layer.activation}." + ) + + # Map vLLM MoEActivation to FlashInfer ActivationType. + activation_map = { + MoEActivation.SILU: ActivationType.Swiglu, + MoEActivation.RELU2_NO_MUL: ActivationType.Relu2, + } + fi_activation_type: ActivationType = activation_map[layer.activation] + + # DeepSeekV3 routing requires float32 logits; others expect bfloat16. + if layer.routing_method_type == RoutingMethodType.DeepSeekV3: + assert router_logits.dtype == torch.float32, ( + "DeepSeekV3 routing requires float32 router_logits, " + f"got {router_logits.dtype}." + ) + else: + router_logits = router_logits.to(torch.bfloat16) + + # Treat 0 as "unset" for compatibility with ungrouped routing configs. + n_group = layer.num_expert_group or None + topk_group = layer.topk_group or None + + hidden_states_mxfp8, hidden_states_scale = mxfp8_e4m3_quantize( + x, + is_sf_swizzled_layout=False, + ) + + kwargs: dict = dict( + routing_logits=router_logits, + routing_bias=layer.e_score_correction_bias, + hidden_states=hidden_states_mxfp8, + hidden_states_scale=hidden_states_scale, + gemm1_weights=layer.w13_weight, + gemm1_weights_scale=layer.w13_weight_scale, + gemm2_weights=layer.w2_weight, + gemm2_weights_scale=layer.w2_weight_scale, + num_experts=layer.global_num_experts, + top_k=layer.top_k, + # Keep Optional semantics: FlashInfer expects None for non-grouped + # routing (e.g. Qwen3 Renormalize), not 0. + n_group=n_group, + topk_group=topk_group, + intermediate_size=layer.intermediate_size_per_partition, + local_expert_offset=layer.ep_rank * layer.local_num_experts, + local_num_experts=layer.local_num_experts, + routed_scaling_factor=layer.routed_scaling_factor, + routing_method_type=layer.routing_method_type, + use_shuffled_weight=True, + weight_layout=0, + fp8_quantization_type=Fp8QuantizationType.MxFp8, + ) + + if fi_activation_type != ActivationType.Swiglu: + raise NotImplementedError( + "FlashInfer TRTLLM MXFP8 MoE supports only Swiglu activation, " + f"got {fi_activation_type}." + ) + + return flashinfer_trtllm_fp8_block_scale_moe(**kwargs) + + def apply( + self, + layer: FusedMoE, + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + shared_experts_input: torch.Tensor | None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + assert not self.is_monolithic + raise NotImplementedError( + "Non-monolithic MXFP8 MoE path is not yet implemented." + ) + + # Register the method classes for ModelOptMxFp8Config ModelOptMxFp8Config.LinearMethodCls = ModelOptMxFp8LinearMethod +ModelOptMxFp8Config.FusedMoEMethodCls = ModelOptMxFp8FusedMoE ModelOptMxFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod