def test_flashinfer_blockscale_fp8_none_expert_group(monkeypatch):
    """Test that flashinfer_fused_moe_blockscale_fp8 handles num_expert_group=None.

    Regression test for https://github.com/vllm-project/vllm/issues/34477
    MiniMax-M2.1 uses sigmoid scoring with e_score_correction_bias but no
    grouped top-k, resulting in num_expert_group=None. This triggered a crash
    in the flashinfer kernel when DeepSeekV3 routing was selected.
    """
    if not current_platform.has_device_capability(100):
        pytest.skip("Test requires SM >= 100 (Blackwell)")

    # Importing the module registers the custom op used below.
    import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401
    from tests.kernels.quant_utils import native_per_token_group_quant_fp8

    set_random_seed(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")

    e = 16  # num_experts (must be divisible by 4)
    topk = 6  # top_k > 1 triggers DeepSeekV3 routing with sigmoid
    m, n, k = 10, 4096, 5120
    block_shape = [128, 128]
    block_k = block_shape[1]

    def _quantize_experts(weights):
        # Quantize each expert's weight matrix groupwise (group size block_k
        # along the last dim) to FP8, returning stacked tensors + scales.
        quantized, scales = zip(
            *(native_per_token_group_quant_fp8(weights[i], block_k) for i in range(e))
        )
        return torch.stack(quantized), torch.stack(scales)

    with set_current_vllm_config(vllm_config):
        # BF16 activations, scaled down to keep values in a sane range.
        x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10

        # Random BF16 weights for the fused gate/up (w13) and down (w2) projections.
        w13_bf16 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16) / 10
        w2_bf16 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) / 10

        w13_fp8, w13_scale = _quantize_experts(w13_bf16)
        w2_fp8, w2_scale = _quantize_experts(w2_bf16)

        # DeepSeekV3 routing uses float32 logits + optional bias
        routing_logits = torch.randn((m, e), device="cuda", dtype=torch.float32)
        routing_bias = torch.randn(e, device="cuda", dtype=torch.float32)

        # This should NOT crash with num_expert_group=None
        output = torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
            routing_logits=routing_logits,
            routing_bias=routing_bias,
            x=x,
            w13_weight=w13_fp8,
            w13_weight_scale_inv=w13_scale,
            w2_weight=w2_fp8,
            w2_weight_scale_inv=w2_scale,
            global_num_experts=e,
            top_k=topk,
            num_expert_group=None,
            topk_group=None,
            intermediate_size=n,
            expert_offset=0,
            local_num_experts=e,
            block_shape=block_shape,
            routing_method_type=RoutingMethodType.DeepSeekV3,
            routed_scaling=1.0,
        )

    assert output is not None
    assert output.shape == (m, k)