From 7d98f09b1cc06ba485d8c79e1f2f5b0062e2e2f6 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 3 Feb 2026 16:26:51 -0500 Subject: [PATCH] cherry pick Signed-off-by: Robert Shaw --- .../layers/fused_moe/flashinfer_trtllm_moe.py | 35 ++++++++++++++++--- .../layers/fused_moe/oracle/fp8.py | 8 ++--- .../compressed_tensors_moe.py | 15 ++------ .../model_executor/layers/quantization/fp8.py | 15 ++------ vllm/model_executor/models/minimax_m2.py | 1 + 5 files changed, 42 insertions(+), 32 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index 0d7473aaf..31351fbfe 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -86,7 +86,23 @@ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bo return not moe_parallel_config.enable_eplb -def is_supported_config_trtllm( +def _supports_router_logits_dtype( + router_logits_dtype: torch.dtype | None, + routing_method: RoutingMethodType, +) -> bool: + """ + The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default. + Only DeepSeekV3 routing supports float32 router_logits (which is converted + internally in the kernel). + """ + if router_logits_dtype == torch.float32: + # Only DeepSeekV3 routing handles float32 logits + # https://github.com/flashinfer-ai/flashinfer/issues/2469 + return routing_method == RoutingMethodType.DeepSeekV3 + return True + + +def is_supported_config_trtllm_fp8( moe_config: FusedMoEConfig, weight_key: QuantKey | None, activation_key: QuantKey | None, @@ -115,13 +131,17 @@ def is_supported_config_trtllm( return False, _make_reason("routing method") elif activation_format != mk.FusedMoEActivationFormat.Standard: return False, _make_reason("activation format") + elif not _supports_router_logits_dtype( + moe_config.router_logits_dtype, moe_config.routing_method + ): + return False, _make_reason("float32 router_logits with non-DeepSeekV3 routing") return True, None def flashinfer_fused_moe_blockscale_fp8( routing_logits: torch.Tensor, - routing_bias: torch.Tensor, + routing_bias: torch.Tensor | None, x: torch.Tensor, w13_weight: torch.Tensor, w13_weight_scale_inv: torch.Tensor, @@ -135,7 +155,7 @@ def flashinfer_fused_moe_blockscale_fp8( expert_offset: int, local_num_experts: int, block_shape: list[int], - routing_method_type: int = int(RoutingMethodType.DeepSeekV3), + routing_method_type: int, routed_scaling: float | None = 1.0, ) -> torch.Tensor: from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe @@ -148,6 +168,13 @@ def flashinfer_fused_moe_blockscale_fp8( # Routing kernel expects #experts <= #threads 512 assert global_num_experts <= 512 + # The DeepSeekV3 routing method requires float32 router logits. + if routing_method_type == RoutingMethodType.DeepSeekV3: + routing_logits = routing_logits.to(torch.float32) + + if routing_bias is not None: + routing_bias = routing_bias.to(x.dtype) + a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1]) # NOTE: scales of hidden states have to be transposed! a_sf_t = a_sf.t().contiguous() @@ -175,7 +202,7 @@ def flashinfer_fused_moe_blockscale_fp8( def flashinfer_fused_moe_blockscale_fp8_fake( routing_logits: torch.Tensor, - routing_bias: torch.Tensor, + routing_bias: torch.Tensor | None, x: torch.Tensor, w13_weight: torch.Tensor, w13_weight_scale_inv: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index cdf2d291b..c89dc6a86 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -15,7 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import ( fp8_w8a16_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import ( - is_supported_config_trtllm, + is_supported_config_trtllm_fp8, ) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP, @@ -212,7 +212,7 @@ def select_fp8_moe_backend( if fi_backend == FlashinferMoeBackend.TENSORRT_LLM: backend = Fp8MoeBackend.FLASHINFER_TRTLLM - supported, reason = is_supported_config_trtllm( + supported, reason = is_supported_config_trtllm_fp8( config, weight_key, activation_key, activation_format ) if supported: @@ -239,7 +239,7 @@ def select_fp8_moe_backend( ]: if backend == Fp8MoeBackend.FLASHINFER_TRTLLM: k_cls = None - supported, reason = is_supported_config_trtllm( + supported, reason = is_supported_config_trtllm_fp8( config, weight_key, activation_key, @@ -308,7 +308,7 @@ def select_fp8_moe_backend( for backend in AVAILABLE_BACKENDS: if backend == Fp8MoeBackend.FLASHINFER_TRTLLM: k_cls = None - supported, reason = is_supported_config_trtllm( + supported, reason = is_supported_config_trtllm_fp8( config, weight_key, activation_key, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index ec120cab4..a8b064584 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -27,7 +27,6 @@ from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, - RoutingMethodType, int4_w4a16_moe_quant_config, int4_w4afp8_moe_quant_config, int8_w8a8_moe_quant_config, @@ -1072,17 +1071,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): if self.block_quant: import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 - e_score_correction_bias = ( - layer.e_score_correction_bias.to(x.dtype) - if layer.e_score_correction_bias is not None - else None - ) - routing_method_type = layer.routing_method_type return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( - routing_logits=router_logits.to(torch.float32) - if routing_method_type == RoutingMethodType.DeepSeekV3 - else router_logits, - routing_bias=e_score_correction_bias, + routing_logits=router_logits, + routing_bias=layer.e_score_correction_bias, x=x, w13_weight=layer.w13_weight, w13_weight_scale_inv=layer.w13_weight_scale, @@ -1096,7 +1087,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): expert_offset=layer.ep_rank * layer.local_num_experts, local_num_experts=layer.local_num_experts, block_shape=self.weight_block_size, - routing_method_type=routing_method_type, + routing_method_type=layer.routing_method_type, routed_scaling=layer.routed_scaling_factor, ) else: diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index fe59022cb..f8e46ef5e 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -26,7 +26,6 @@ from vllm.model_executor.layers.fused_moe import ( ) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, - RoutingMethodType, ) from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( @@ -990,17 +989,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): if self.block_quant: import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 - e_score_correction_bias = ( - layer.e_score_correction_bias.to(x.dtype) - if layer.e_score_correction_bias is not None - else None - ) - routing_method_type = layer.routing_method_type return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( - routing_logits=router_logits.to(torch.float32) - if routing_method_type == RoutingMethodType.DeepSeekV3 - else router_logits, - routing_bias=e_score_correction_bias, + routing_logits=router_logits, + routing_bias=layer.e_score_correction_bias, x=x, w13_weight=layer.w13_weight, w13_weight_scale_inv=layer.w13_weight_scale_inv, @@ -1014,7 +1005,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): expert_offset=layer.ep_rank * layer.local_num_experts, local_num_experts=layer.local_num_experts, block_shape=self.weight_block_size, - routing_method_type=routing_method_type, + routing_method_type=layer.routing_method_type, routed_scaling=layer.routed_scaling_factor, ) else: diff --git a/vllm/model_executor/models/minimax_m2.py b/vllm/model_executor/models/minimax_m2.py index 95b036ac2..3bc13c7fd 100644 --- a/vllm/model_executor/models/minimax_m2.py +++ b/vllm/model_executor/models/minimax_m2.py @@ -107,6 +107,7 @@ class MiniMaxM2MoE(nn.Module): renormalize=True, quant_config=quant_config, prefix=f"{prefix}.experts", + router_logits_dtype=torch.float32, ) self.gate = ReplicatedLinear(