diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py
index eddc395cc..eb3d9f8a8 100644
--- a/tests/kernels/moe/test_moe.py
+++ b/tests/kernels/moe/test_moe.py
@@ -396,6 +396,55 @@ def test_fused_moe(
     )
 
 
+def test_fused_moe_int64_overflow(monkeypatch, workspace_init):
+    """Regression test for int32 overflow in stride*offset products.
+
+    When chunking is disabled and M is large, stride_cm * offs_token can
+    exceed int32 max. Verifies the offs_token int64 cast (fix for #34413)
+    prevents overflow and produces correct results.
+
+    Reproduces the scenario from PR #34279.
+    """
+    # ~12 GB GPU memory needed for intermediate caches
+    free_mem = torch.cuda.mem_get_info()[0]
+    if free_mem < 12 * 1024**3:
+        pytest.skip("Insufficient GPU memory for overflow test")
+
+    set_random_seed(7)
+
+    m, n, k, e, topk = 100000, 2048, 1024, 8, 6
+    dtype = torch.bfloat16
+
+    # Disable chunking to expose the overflow-prone code path
+    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "10000000")
+
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
+    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
+
+    # Verify the test exercises the overflow condition:
+    # C has shape (M, topk, N) where N = w1.size(1) = 2*n
+    # stride_cm = C.stride(1) = N, max offs_token = M * topk
+    # Product must exceed int32 max for this test to be meaningful
+    N = w1.size(1)
+    assert N * m * topk > 2**31 - 1, "Test params don't trigger int32 overflow"
+
+    fused_moe_fn = functools.partial(fused_moe, renormalize=False)
+
+    with set_current_vllm_config(vllm_config):
+        run_moe_test(
+            torch_moe,
+            fused_moe_fn,
+            a=a,
+            w1=w1,
+            w2=w2,
+            score=score,
+            topk=topk,
+            global_num_experts=e,
+        )
+
+
 @pytest.mark.parametrize("m,n,k", FUSED_MOE_MNK_FACTORS_SMALL_M)
 @pytest.mark.parametrize("e", NUM_EXPERTS_LARGE)
 @pytest.mark.parametrize("topk", TOP_KS_SMALL)
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 5240f79be..a80978772 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -175,7 +175,8 @@ def fused_moe_kernel_gptq_awq(
     if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
         return
     offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
-    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
+    # Cast to int64 to prevent overflow in stride*offset products
+    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id).to(tl.int64)
     token_mask = offs_token < num_valid_tokens
 
     off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
@@ -426,6 +427,9 @@ def fused_moe_kernel(
         pid_m,  # first element = pid_m
         num_valid_tokens,  # remaining elements = constant
     )
+    # Cast to int64 to prevent overflow in stride*offset products
+    # (e.g. stride_cm * offs_token can exceed int32 for large token counts)
+    offs_token = offs_token.to(tl.int64)
     token_mask = offs_token < num_valid_tokens