[Bugfix] FlashInfer MXINT4 MoE crashes, missing do_finalize (#39315)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com> Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
ba4a78eb5d
commit
8332078cfd
@@ -269,3 +269,96 @@ def test_marlin_vs_trtllm_mxint4_moe_kimik2(monkeypatch, m, n, k, e, topk, group
|
|||||||
# Note: Different quantization schemes (UINT4b8 vs signed MXINT4) cause
|
# Note: Different quantization schemes (UINT4b8 vs signed MXINT4) cause
|
||||||
# some differences
|
# some differences
|
||||||
torch.testing.assert_close(marlin_output, trtllm_output, atol=0.3, rtol=6.0)
|
torch.testing.assert_close(marlin_output, trtllm_output, atol=0.3, rtol=6.0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not TRTLLM_GEN_AVAILABLE, reason="Skip for non SM100")
|
||||||
|
@pytest.mark.parametrize("m", [1, 33])
|
||||||
|
@pytest.mark.parametrize("n", [7168])
|
||||||
|
@pytest.mark.parametrize("k", [512])
|
||||||
|
@pytest.mark.parametrize("e", [384])
|
||||||
|
@pytest.mark.parametrize("topk", [8])
|
||||||
|
@torch.inference_mode()
|
||||||
|
def test_flashinfer_trtllm_mxint4_moe_wrapper(m, n, k, e, topk):
|
||||||
|
"""Test that the flashinfer_trtllm_mxint4_moe wrapper matches the raw
|
||||||
|
trtllm_mxint4_block_scale_moe kernel call."""
|
||||||
|
pytest.importorskip("flashinfer")
|
||||||
|
from flashinfer import RoutingMethodType
|
||||||
|
from flashinfer.fused_moe import trtllm_mxint4_block_scale_moe
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.quantization.utils.flashinfer_mxint4_moe import (
|
||||||
|
flashinfer_trtllm_mxint4_moe,
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.cuda.manual_seed(0)
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
|
||||||
|
a = torch.randn((m, k), device="cuda", dtype=dtype) * 0.5
|
||||||
|
router_logits = torch.randn((m, e), device="cuda", dtype=torch.float32) * 1.5
|
||||||
|
routing_bias = torch.randn(e, device="cuda", dtype=torch.float32) * 0.8
|
||||||
|
|
||||||
|
std_w1 = (2.0 / (k + 2 * n)) ** 0.5
|
||||||
|
std_w2 = (2.0 / (n + k)) ** 0.5
|
||||||
|
w1_bf16 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) * std_w1
|
||||||
|
w2_bf16 = torch.randn((e, k, n), device="cuda", dtype=dtype) * std_w2
|
||||||
|
|
||||||
|
w1_int4, w1_scales = mxint4_quantize_moe_weights(w1_bf16)
|
||||||
|
w2_int4, w2_scales = mxint4_quantize_moe_weights(w2_bf16)
|
||||||
|
|
||||||
|
prepared = prepare_static_weights_for_trtllm_mxint4_moe(
|
||||||
|
gemm1_weights=w1_int4,
|
||||||
|
gemm1_scales=w1_scales,
|
||||||
|
gemm2_weights=w2_int4,
|
||||||
|
gemm2_scales=w2_scales,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Raw kernel call (reference)
|
||||||
|
raw_out = trtllm_mxint4_block_scale_moe(
|
||||||
|
routing_logits=router_logits.to(torch.float32),
|
||||||
|
routing_bias=routing_bias.to(torch.bfloat16),
|
||||||
|
hidden_states=a,
|
||||||
|
gemm1_weights=prepared["gemm1_weights"].data,
|
||||||
|
gemm1_weights_scale=prepared["gemm1_scales"].data,
|
||||||
|
gemm1_alpha=None,
|
||||||
|
gemm1_beta=None,
|
||||||
|
gemm1_clamp_limit=None,
|
||||||
|
gemm2_weights=prepared["gemm2_weights"].data,
|
||||||
|
gemm2_weights_scale=prepared["gemm2_scales"].data,
|
||||||
|
num_experts=e,
|
||||||
|
top_k=topk,
|
||||||
|
n_group=1,
|
||||||
|
topk_group=1,
|
||||||
|
intermediate_size=n,
|
||||||
|
local_expert_offset=0,
|
||||||
|
local_num_experts=e,
|
||||||
|
routed_scaling_factor=None,
|
||||||
|
routing_method_type=RoutingMethodType.DeepSeekV3,
|
||||||
|
enable_pdl=None,
|
||||||
|
output=None,
|
||||||
|
tune_max_num_tokens=8192,
|
||||||
|
)
|
||||||
|
if not isinstance(raw_out, torch.Tensor):
|
||||||
|
raw_out = raw_out[0]
|
||||||
|
raw_out = raw_out.to(dtype)
|
||||||
|
|
||||||
|
# Wrapper call
|
||||||
|
wrapper_out = flashinfer_trtllm_mxint4_moe(
|
||||||
|
x=a,
|
||||||
|
router_logits=router_logits,
|
||||||
|
w13_weight_packed=prepared["gemm1_weights"],
|
||||||
|
w13_weight_scale=prepared["gemm1_scales"],
|
||||||
|
w2_weight_packed=prepared["gemm2_weights"],
|
||||||
|
w2_weight_scale=prepared["gemm2_scales"],
|
||||||
|
global_num_experts=e,
|
||||||
|
top_k=topk,
|
||||||
|
intermediate_size_per_partition=n,
|
||||||
|
local_num_experts=e,
|
||||||
|
ep_rank=0,
|
||||||
|
num_expert_group=1,
|
||||||
|
topk_group=1,
|
||||||
|
e_score_correction_bias=routing_bias,
|
||||||
|
routing_method_type=RoutingMethodType.DeepSeekV3,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert wrapper_out.shape == (m, k)
|
||||||
|
assert wrapper_out.dtype == dtype
|
||||||
|
torch.testing.assert_close(wrapper_out, raw_out, atol=0.0, rtol=0.0)
|
||||||
|
|||||||
@@ -259,8 +259,12 @@ def flashinfer_trtllm_mxint4_moe(
|
|||||||
routed_scaling_factor=None,
|
routed_scaling_factor=None,
|
||||||
routing_method_type=routing_method_type,
|
routing_method_type=routing_method_type,
|
||||||
enable_pdl=None,
|
enable_pdl=None,
|
||||||
|
do_finalize=True,
|
||||||
output=None,
|
output=None,
|
||||||
tune_max_num_tokens=8192,
|
tune_max_num_tokens=8192,
|
||||||
).to(x.dtype)
|
)
|
||||||
|
if isinstance(out, (tuple, list)):
|
||||||
|
out = out[0]
|
||||||
|
out = out.to(x.dtype)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|||||||
Reference in New Issue
Block a user