[Bug] Fix moe_sum signature (#18440)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm
2025-05-21 01:37:08 -04:00
committed by GitHub
parent 0c15c2e486
commit 92247c522e
2 changed files with 19 additions and 1 deletions

View File

@@ -575,3 +575,21 @@ def test_moe_align_block_size_opcheck():
opcheck(torch.ops._moe_C.moe_align_block_size,
(topk_ids, num_experts, block_size, sorted_ids, expert_ids,
num_tokens_post_pad))
@pytest.mark.parametrize("m", [1, 33, 222, 1024 * 128])
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
input = torch.randn((m, topk, k), device="cuda", dtype=dtype)
actual = torch.empty((m, k), device="cuda", dtype=dtype)
expected = input.sum(dim=1)
torch.ops._moe_C.moe_sum(input, actual)
torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0)
opcheck(torch.ops._moe_C.moe_sum, (input, actual))