From 8332078cfdbd5e44e527893b695e79052d008172 Mon Sep 17 00:00:00 2001
From: Benjamin Chislett
Date: Wed, 8 Apr 2026 20:36:33 -0400
Subject: [PATCH] [Bugfix] FlashInfer MXINT4 MoE crashes, missing do_finalize (#39315)

Signed-off-by: Benjamin Chislett
Signed-off-by: Benjamin Chislett
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
Reviewer notes (free text after the three-dash line; not part of the
commit message and not part of the diff):

* flashinfer_trtllm_mxint4_moe: the call whose result is bound to `out`
  now receives do_finalize=True. If that call returns a tuple or a list,
  only element 0 is kept, and the result is then cast to x.dtype.
  NOTE(review): the callee name lies outside the hunk; presumably it is
  trtllm_mxint4_block_scale_moe -- confirm against the full file.
* test_flashinfer_trtllm_mxint4_moe_wrapper: calls
  trtllm_mxint4_block_scale_moe directly and through the wrapper on the
  same prepared weights, then requires the wrapper output to have shape
  (m, k), dtype bfloat16, and to equal the direct result exactly
  (atol=0.0, rtol=0.0).
* The direct call passes routing_bias cast to bfloat16, while the wrapper
  receives the float32 tensor as e_score_correction_bias.
  NOTE(review): exact equality therefore assumes the wrapper applies the
  same cast internally -- confirm.
* The test is skipped when TRTLLM_GEN_AVAILABLE is false or when
  flashinfer cannot be imported.

 .../moe/test_marlin_vs_trtllm_mxint4.py       | 93 +++++++++++++++++++
 .../utils/flashinfer_mxint4_moe.py            |  6 +-
 2 files changed, 98 insertions(+), 1 deletion(-)

diff --git a/tests/kernels/moe/test_marlin_vs_trtllm_mxint4.py b/tests/kernels/moe/test_marlin_vs_trtllm_mxint4.py
index aaf255ca8..0ce3d165d 100644
--- a/tests/kernels/moe/test_marlin_vs_trtllm_mxint4.py
+++ b/tests/kernels/moe/test_marlin_vs_trtllm_mxint4.py
@@ -269,3 +269,96 @@ def test_marlin_vs_trtllm_mxint4_moe_kimik2(monkeypatch, m, n, k, e, topk, group
     # Note: Different quantization schemes (UINT4b8 vs signed MXINT4) cause
     # some differences
     torch.testing.assert_close(marlin_output, trtllm_output, atol=0.3, rtol=6.0)
+
+
+@pytest.mark.skipif(not TRTLLM_GEN_AVAILABLE, reason="Skip for non SM100")
+@pytest.mark.parametrize("m", [1, 33])
+@pytest.mark.parametrize("n", [7168])
+@pytest.mark.parametrize("k", [512])
+@pytest.mark.parametrize("e", [384])
+@pytest.mark.parametrize("topk", [8])
+@torch.inference_mode()
+def test_flashinfer_trtllm_mxint4_moe_wrapper(m, n, k, e, topk):
+    """Test that the flashinfer_trtllm_mxint4_moe wrapper matches the raw
+    trtllm_mxint4_block_scale_moe kernel call."""
+    pytest.importorskip("flashinfer")
+    from flashinfer import RoutingMethodType
+    from flashinfer.fused_moe import trtllm_mxint4_block_scale_moe
+
+    from vllm.model_executor.layers.quantization.utils.flashinfer_mxint4_moe import (
+        flashinfer_trtllm_mxint4_moe,
+    )
+
+    torch.cuda.manual_seed(0)
+    dtype = torch.bfloat16
+
+    a = torch.randn((m, k), device="cuda", dtype=dtype) * 0.5
+    router_logits = torch.randn((m, e), device="cuda", dtype=torch.float32) * 1.5
+    routing_bias = torch.randn(e, device="cuda", dtype=torch.float32) * 0.8
+
+    std_w1 = (2.0 / (k + 2 * n)) ** 0.5
+    std_w2 = (2.0 / (n + k)) ** 0.5
+    w1_bf16 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) * std_w1
+    w2_bf16 = torch.randn((e, k, n), device="cuda", dtype=dtype) * std_w2
+
+    w1_int4, w1_scales = mxint4_quantize_moe_weights(w1_bf16)
+    w2_int4, w2_scales = mxint4_quantize_moe_weights(w2_bf16)
+
+    prepared = prepare_static_weights_for_trtllm_mxint4_moe(
+        gemm1_weights=w1_int4,
+        gemm1_scales=w1_scales,
+        gemm2_weights=w2_int4,
+        gemm2_scales=w2_scales,
+    )
+
+    # Raw kernel call (reference)
+    raw_out = trtllm_mxint4_block_scale_moe(
+        routing_logits=router_logits.to(torch.float32),
+        routing_bias=routing_bias.to(torch.bfloat16),
+        hidden_states=a,
+        gemm1_weights=prepared["gemm1_weights"].data,
+        gemm1_weights_scale=prepared["gemm1_scales"].data,
+        gemm1_alpha=None,
+        gemm1_beta=None,
+        gemm1_clamp_limit=None,
+        gemm2_weights=prepared["gemm2_weights"].data,
+        gemm2_weights_scale=prepared["gemm2_scales"].data,
+        num_experts=e,
+        top_k=topk,
+        n_group=1,
+        topk_group=1,
+        intermediate_size=n,
+        local_expert_offset=0,
+        local_num_experts=e,
+        routed_scaling_factor=None,
+        routing_method_type=RoutingMethodType.DeepSeekV3,
+        enable_pdl=None,
+        output=None,
+        tune_max_num_tokens=8192,
+    )
+    if not isinstance(raw_out, torch.Tensor):
+        raw_out = raw_out[0]
+    raw_out = raw_out.to(dtype)
+
+    # Wrapper call
+    wrapper_out = flashinfer_trtllm_mxint4_moe(
+        x=a,
+        router_logits=router_logits,
+        w13_weight_packed=prepared["gemm1_weights"],
+        w13_weight_scale=prepared["gemm1_scales"],
+        w2_weight_packed=prepared["gemm2_weights"],
+        w2_weight_scale=prepared["gemm2_scales"],
+        global_num_experts=e,
+        top_k=topk,
+        intermediate_size_per_partition=n,
+        local_num_experts=e,
+        ep_rank=0,
+        num_expert_group=1,
+        topk_group=1,
+        e_score_correction_bias=routing_bias,
+        routing_method_type=RoutingMethodType.DeepSeekV3,
+    )
+
+    assert wrapper_out.shape == (m, k)
+    assert wrapper_out.dtype == dtype
+    torch.testing.assert_close(wrapper_out, raw_out, atol=0.0, rtol=0.0)
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_mxint4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_mxint4_moe.py
index 98a3d1e12..4e08a73a6 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_mxint4_moe.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_mxint4_moe.py
@@ -259,8 +259,12 @@ def flashinfer_trtllm_mxint4_moe(
         routed_scaling_factor=None,
         routing_method_type=routing_method_type,
         enable_pdl=None,
+        do_finalize=True,
         output=None,
         tune_max_num_tokens=8192,
-    ).to(x.dtype)
+    )
+    if isinstance(out, (tuple, list)):
+        out = out[0]
+    out = out.to(x.dtype)
 
     return out