[Bugfix] Fix fused MoE int32 overflow in stride*offset without perf regression (#34507)
Signed-off-by: haosdent <haosdent@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
@@ -396,6 +396,55 @@ def test_fused_moe(
     )
 
 
+def test_fused_moe_int64_overflow(monkeypatch, workspace_init):
+    """Regression test for int32 overflow in stride*offset products.
+
+    When chunking is disabled and M is large, stride_cm * offs_token can
+    exceed int32 max. Verifies the offs_token int64 cast (fix for #34413)
+    prevents overflow and produces correct results.
+
+    Reproduces the scenario from PR #34279.
+    """
+    # ~12 GB GPU memory needed for intermediate caches
+    free_mem = torch.cuda.mem_get_info()[0]
+    if free_mem < 12 * 1024**3:
+        pytest.skip("Insufficient GPU memory for overflow test")
+
+    set_random_seed(7)
+
+    m, n, k, e, topk = 100000, 2048, 1024, 8, 6
+    dtype = torch.bfloat16
+
+    # Disable chunking to expose the overflow-prone code path
+    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "10000000")
+
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
+    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
+
+    # Verify the test exercises the overflow condition:
+    # C has shape (M, topk, N) where N = w1.size(1) = 2*n
+    # stride_cm = C.stride(1) = N, max offs_token = M * topk
+    # Product must exceed int32 max for this test to be meaningful
+    N = w1.size(1)
+    assert N * m * topk > 2**31 - 1, "Test params don't trigger int32 overflow"
+
+    fused_moe_fn = functools.partial(fused_moe, renormalize=False)
+
+    with set_current_vllm_config(vllm_config):
+        run_moe_test(
+            torch_moe,
+            fused_moe_fn,
+            a=a,
+            w1=w1,
+            w2=w2,
+            score=score,
+            topk=topk,
+            global_num_experts=e,
+        )
+
+
 @pytest.mark.parametrize("m,n,k", FUSED_MOE_MNK_FACTORS_SMALL_M)
 @pytest.mark.parametrize("e", NUM_EXPERTS_LARGE)
 @pytest.mark.parametrize("topk", TOP_KS_SMALL)
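For readers checking the overflow claim in the test's assert: with these parameters, stride_cm = 2*n = 4096 and the largest offs_token = m * topk = 600000, so the flat offset is 2,457,600,000, past the int32 max of 2,147,483,647. A minimal sketch (not part of this commit) showing the wraparound with NumPy's fixed-width integers:

import numpy as np

n, m, topk = 2048, 100000, 6
stride_cm = 2 * n            # C.stride(1) = N = w1.size(1) for contiguous (M, topk, N)
max_offs_token = m * topk    # largest flattened token index

# 64-bit arithmetic gives the true offset: 2457600000
print(np.int64(stride_cm) * np.int64(max_offs_token))

# 32-bit arithmetic wraps to a negative offset: -1837367296
# (NumPy emits an overflow RuntimeWarning here)
print(np.int32(stride_cm) * np.int32(max_offs_token))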
@@ -175,7 +175,8 @@ def fused_moe_kernel_gptq_awq(
     if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
         return
     offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
-    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
+    # Cast to int64 to prevent overflow in stride*offset products
+    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id).to(tl.int64)
     token_mask = offs_token < num_valid_tokens
 
     off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
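The same idiom applies anywhere a loaded int32 index feeds pointer arithmetic: tl.load returns values in the pointer's element type, so the cast must happen before the stride multiply. A hypothetical, self-contained Triton kernel illustrating the pattern (all names invented for illustration, not from this diff):

import torch
import triton
import triton.language as tl


@triton.jit
def gather_kernel(src_ptr, idx_ptr, out_ptr, stride_row, n_idx, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_idx
    # Cast before multiplying by the stride; idx is int32 as loaded, and
    # idx * stride_row evaluated in int32 can wrap for large tensors.
    idx = tl.load(idx_ptr + offs, mask=mask, other=0).to(tl.int64)
    val = tl.load(src_ptr + idx * stride_row, mask=mask, other=0.0)
    tl.store(out_ptr + offs, val, mask=mask)


# Example launch: gather the first element of 1024 rows of width 1024
src = torch.randn(1 << 20, device="cuda")
idx = torch.randint(0, 1 << 10, (1024,), device="cuda", dtype=torch.int32)
out = torch.empty(1024, device="cuda")
gather_kernel[(1,)](src, idx, out, stride_row=1 << 10, n_idx=1024, BLOCK=1024)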
@@ -426,6 +427,9 @@ def fused_moe_kernel(
         pid_m,  # first element = pid_m
         num_valid_tokens,  # remaining elements = constant
     )
+    # Cast to int64 to prevent overflow in stride*offset products
+    # (e.g. stride_cm * offs_token can exceed int32 for large token counts)
+    offs_token = offs_token.to(tl.int64)
 
     token_mask = offs_token < num_valid_tokens
 
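To reproduce locally once merged, the regression test can be selected by name; the file path below is an assumption about the test layout, not stated in this diff:

pytest tests/kernels/moe/test_fused_moe.py -k test_fused_moe_int64_overflow

Note the guard at the top of the test: on devices with less than ~12 GB of free GPU memory it skips rather than fails.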