[Bug] Fix moe_sum signature (#18440)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -575,3 +575,21 @@ def test_moe_align_block_size_opcheck():
|
||||
opcheck(torch.ops._moe_C.moe_align_block_size,
|
||||
(topk_ids, num_experts, block_size, sorted_ids, expert_ids,
|
||||
num_tokens_post_pad))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1, 33, 222, 1024 * 128])
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("k", [128, 511, 1024])
|
||||
@pytest.mark.parametrize("dtype",
|
||||
[torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
|
||||
def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
|
||||
input = torch.randn((m, topk, k), device="cuda", dtype=dtype)
|
||||
actual = torch.empty((m, k), device="cuda", dtype=dtype)
|
||||
|
||||
expected = input.sum(dim=1)
|
||||
torch.ops._moe_C.moe_sum(input, actual)
|
||||
|
||||
torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0)
|
||||
|
||||
opcheck(torch.ops._moe_C.moe_sum, (input, actual))
|
||||
|
||||
Reference in New Issue
Block a user