[Bugfix] FlashInfer MXINT4 MoE crashes, missing do_finalize (#39315)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Benjamin Chislett
2026-04-08 20:36:33 -04:00
committed by GitHub
parent ba4a78eb5d
commit 8332078cfd
2 changed files with 98 additions and 1 deletion

View File

@@ -269,3 +269,96 @@ def test_marlin_vs_trtllm_mxint4_moe_kimik2(monkeypatch, m, n, k, e, topk, group
# Note: Different quantization schemes (UINT4b8 vs signed MXINT4) cause
# some differences
torch.testing.assert_close(marlin_output, trtllm_output, atol=0.3, rtol=6.0)
@pytest.mark.skipif(not TRTLLM_GEN_AVAILABLE, reason="Skip for non SM100")
@pytest.mark.parametrize("m", [1, 33])
@pytest.mark.parametrize("n", [7168])
@pytest.mark.parametrize("k", [512])
@pytest.mark.parametrize("e", [384])
@pytest.mark.parametrize("topk", [8])
@torch.inference_mode()
def test_flashinfer_trtllm_mxint4_moe_wrapper(m, n, k, e, topk):
    """Parity check: the flashinfer_trtllm_mxint4_moe wrapper must produce
    bit-identical output to a direct trtllm_mxint4_block_scale_moe call
    on the same quantized weights and routing inputs."""
    pytest.importorskip("flashinfer")
    from flashinfer import RoutingMethodType
    from flashinfer.fused_moe import trtllm_mxint4_block_scale_moe
    from vllm.model_executor.layers.quantization.utils.flashinfer_mxint4_moe import (
        flashinfer_trtllm_mxint4_moe,
    )

    torch.cuda.manual_seed(0)
    act_dtype = torch.bfloat16

    # Activations and DeepSeek-V3-style routing inputs.
    hidden = torch.randn((m, k), device="cuda", dtype=act_dtype) * 0.5
    logits = torch.randn((m, e), device="cuda", dtype=torch.float32) * 1.5
    bias = torch.randn(e, device="cuda", dtype=torch.float32) * 0.8

    # Xavier-like scales for the two GEMM weight tensors.
    w1_std = (2.0 / (k + 2 * n)) ** 0.5
    w2_std = (2.0 / (n + k)) ** 0.5
    gemm1_bf16 = torch.randn((e, 2 * n, k), device="cuda", dtype=act_dtype) * w1_std
    gemm2_bf16 = torch.randn((e, k, n), device="cuda", dtype=act_dtype) * w2_std

    # Quantize to MXINT4 and repack into the kernel's expected layout.
    gemm1_int4, gemm1_sc = mxint4_quantize_moe_weights(gemm1_bf16)
    gemm2_int4, gemm2_sc = mxint4_quantize_moe_weights(gemm2_bf16)
    prepared = prepare_static_weights_for_trtllm_mxint4_moe(
        gemm1_weights=gemm1_int4,
        gemm1_scales=gemm1_sc,
        gemm2_weights=gemm2_int4,
        gemm2_scales=gemm2_sc,
    )

    # Reference: invoke the raw kernel directly.
    reference = trtllm_mxint4_block_scale_moe(
        routing_logits=logits.to(torch.float32),
        routing_bias=bias.to(torch.bfloat16),
        hidden_states=hidden,
        gemm1_weights=prepared["gemm1_weights"].data,
        gemm1_weights_scale=prepared["gemm1_scales"].data,
        gemm1_alpha=None,
        gemm1_beta=None,
        gemm1_clamp_limit=None,
        gemm2_weights=prepared["gemm2_weights"].data,
        gemm2_weights_scale=prepared["gemm2_scales"].data,
        num_experts=e,
        top_k=topk,
        n_group=1,
        topk_group=1,
        intermediate_size=n,
        local_expert_offset=0,
        local_num_experts=e,
        routed_scaling_factor=None,
        routing_method_type=RoutingMethodType.DeepSeekV3,
        enable_pdl=None,
        output=None,
        tune_max_num_tokens=8192,
    )
    # The kernel may return a bare tensor or a tuple whose first entry
    # is the output tensor; normalize before comparing.
    reference = reference if isinstance(reference, torch.Tensor) else reference[0]
    reference = reference.to(act_dtype)

    # Candidate: the vLLM wrapper over the same kernel.
    candidate = flashinfer_trtllm_mxint4_moe(
        x=hidden,
        router_logits=logits,
        w13_weight_packed=prepared["gemm1_weights"],
        w13_weight_scale=prepared["gemm1_scales"],
        w2_weight_packed=prepared["gemm2_weights"],
        w2_weight_scale=prepared["gemm2_scales"],
        global_num_experts=e,
        top_k=topk,
        intermediate_size_per_partition=n,
        local_num_experts=e,
        ep_rank=0,
        num_expert_group=1,
        topk_group=1,
        e_score_correction_bias=bias,
        routing_method_type=RoutingMethodType.DeepSeekV3,
    )

    assert candidate.shape == (m, k)
    assert candidate.dtype == act_dtype
    # Zero tolerance: the wrapper must match the raw kernel exactly.
    torch.testing.assert_close(candidate, reference, atol=0.0, rtol=0.0)