[Bugfix] Fix fused MoE int32 overflow in stride*offset without perf regression (#34507)

Signed-off-by: haosdent <haosdent@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
haosdent
2026-02-17 09:58:49 +08:00
committed by GitHub
parent 0b5f9b7204
commit b68fd899d1
2 changed files with 54 additions and 1 deletions

View File

@@ -396,6 +396,55 @@ def test_fused_moe(
)
def test_fused_moe_int64_overflow(monkeypatch, workspace_init):
    """Regression test for int32 overflow in stride*offset products.

    With chunking effectively disabled and a very large M, the product
    stride_cm * offs_token can exceed the int32 range. Verifies that the
    int64 cast applied to offs_token (fix for #34413) prevents the
    overflow and produces correct results. Mirrors the scenario from
    PR #34279.
    """
    # The intermediate caches need roughly 12 GB of free device memory.
    if torch.cuda.mem_get_info()[0] < 12 * 1024**3:
        pytest.skip("Insufficient GPU memory for overflow test")

    set_random_seed(7)
    m, n, k, e, topk = 100000, 2048, 1024, 8, 6
    dtype = torch.bfloat16

    # A huge chunk size effectively disables chunking, which exposes the
    # overflow-prone code path.
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "10000000")

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=dtype)

    # Sanity-check that these parameters actually hit the overflow
    # condition: C has shape (M, topk, N) where N = w1.size(1) = 2*n,
    # stride_cm = C.stride(1) = N, and the max offs_token is M * topk, so
    # their product must exceed int32 max for the test to be meaningful.
    int32_max = 2**31 - 1
    assert w1.size(1) * m * topk > int32_max, (
        "Test params don't trigger int32 overflow"
    )

    moe_under_test = functools.partial(fused_moe, renormalize=False)
    with set_current_vllm_config(vllm_config):
        run_moe_test(
            torch_moe,
            moe_under_test,
            a=a,
            w1=w1,
            w2=w2,
            score=score,
            topk=topk,
            global_num_experts=e,
        )
@pytest.mark.parametrize("m,n,k", FUSED_MOE_MNK_FACTORS_SMALL_M)
@pytest.mark.parametrize("e", NUM_EXPERTS_LARGE)
@pytest.mark.parametrize("topk", TOP_KS_SMALL)

View File

@@ -175,7 +175,8 @@ def fused_moe_kernel_gptq_awq(
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
# Cast to int64 to prevent overflow in stride*offset products
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id).to(tl.int64)
token_mask = offs_token < num_valid_tokens
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
@@ -426,6 +427,9 @@ def fused_moe_kernel(
pid_m, # first element = pid_m
num_valid_tokens, # remaining elements = constant
)
# Cast to int64 to prevent overflow in stride*offset products
# (e.g. stride_cm * offs_token can exceed int32 for large token counts)
offs_token = offs_token.to(tl.int64)
token_mask = offs_token < num_valid_tokens