[Bugfix] Handle num_expert_group=None in flashinfer block-scale FP8 MoE (#34494)
Signed-off-by: haosdent <haosdent@gmail.com>
This commit is contained in:
@@ -398,3 +398,80 @@ def test_convert_moe_weights_to_flashinfer_trtllm_block_layout(
|
||||
|
||||
assert w13_converted.shape[0] == num_experts
|
||||
assert w2_converted.shape[0] == num_experts
|
||||
|
||||
|
||||
def test_flashinfer_blockscale_fp8_none_expert_group(monkeypatch):
    """Regression test: flashinfer_fused_moe_blockscale_fp8 with num_expert_group=None.

    See https://github.com/vllm-project/vllm/issues/34477

    Models such as MiniMax-M2.1 use sigmoid scoring with an
    e_score_correction_bias but no grouped top-k, so the layer passes
    num_expert_group=None / topk_group=None down to the kernel wrapper.
    That combination previously crashed when DeepSeekV3 routing was
    selected; this test exercises exactly that call path.
    """
    if not current_platform.has_device_capability(100):
        pytest.skip("Test requires SM >= 100 (Blackwell)")

    # Registers the torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8 custom op.
    import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401
    from tests.kernels.quant_utils import native_per_token_group_quant_fp8

    set_random_seed(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")

    num_experts = 16  # must be divisible by 4
    top_k = 6  # top_k > 1 triggers DeepSeekV3 routing with sigmoid
    num_tokens, inter_size, hidden_size = 10, 4096, 5120
    block_shape = [128, 128]
    group_size = block_shape[1]

    with set_current_vllm_config(vllm_config):
        # BF16 activations.
        hidden_states = (
            torch.randn(
                (num_tokens, hidden_size), device="cuda", dtype=torch.bfloat16
            )
            / 10
        )

        # Random BF16 weights to be quantized below.
        w13_bf16 = (
            torch.randn(
                (num_experts, 2 * inter_size, hidden_size),
                device="cuda",
                dtype=torch.bfloat16,
            )
            / 10
        )
        w2_bf16 = (
            torch.randn(
                (num_experts, hidden_size, inter_size),
                device="cuda",
                dtype=torch.bfloat16,
            )
            / 10
        )

        def _quantize_per_expert(weights):
            # Quantize each expert's weight to FP8 with group-wise scales,
            # then stack back into (num_experts, ...) tensors.
            pairs = [
                native_per_token_group_quant_fp8(w, group_size) for w in weights
            ]
            quantized, scales = zip(*pairs)
            return torch.stack(list(quantized)), torch.stack(list(scales))

        w13_fp8, w13_scale = _quantize_per_expert(w13_bf16)
        w2_fp8, w2_scale = _quantize_per_expert(w2_bf16)

        # DeepSeekV3 routing consumes float32 logits plus an optional bias.
        routing_logits = torch.randn(
            (num_tokens, num_experts), device="cuda", dtype=torch.float32
        )
        routing_bias = torch.randn(
            num_experts, device="cuda", dtype=torch.float32
        )

        # The call under test: must not crash with num_expert_group=None.
        output = torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
            routing_logits=routing_logits,
            routing_bias=routing_bias,
            x=hidden_states,
            w13_weight=w13_fp8,
            w13_weight_scale_inv=w13_scale,
            w2_weight=w2_fp8,
            w2_weight_scale_inv=w2_scale,
            global_num_experts=num_experts,
            top_k=top_k,
            num_expert_group=None,
            topk_group=None,
            intermediate_size=inter_size,
            expert_offset=0,
            local_num_experts=num_experts,
            block_shape=block_shape,
            routing_method_type=RoutingMethodType.DeepSeekV3,
            routed_scaling=1.0,
        )

        assert output is not None
        assert output.shape == (num_tokens, hidden_size)
|
||||
|
||||
@@ -201,6 +201,7 @@ def flashinfer_fused_moe_blockscale_fp8(
|
||||
) -> torch.Tensor:
|
||||
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
|
||||
|
||||
num_expert_group = num_expert_group if num_expert_group is not None else 0
|
||||
topk_group = topk_group if topk_group is not None else 0
|
||||
assert top_k <= global_num_experts
|
||||
assert top_k <= 10
|
||||
|
||||
Reference in New Issue
Block a user