[Kernel][MoE] optimize moe_align_block_size (#29642)

Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Jinzhen Lin
2025-12-07 17:58:47 +08:00
committed by GitHub
parent 1b0482b9d1
commit 879ddb09c3
10 changed files with 195 additions and 63 deletions


@@ -955,9 +955,22 @@ def test_fused_marlin_moe_with_bias(m):
     torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
-def test_moe_align_block_size_opcheck():
+@pytest.mark.parametrize("ep_size", [1, 2])
+def test_moe_align_block_size_opcheck(ep_size):
     num_experts = 4
     block_size = 4
+    expert_map = None
+    if ep_size != 1:
+        local_num_experts = num_experts // ep_size
+        expert_ids = torch.randint(
+            0, num_experts, (local_num_experts,), device="cuda", dtype=torch.int32
+        )
+        expert_map = torch.full((num_experts,), -1, device="cuda", dtype=torch.int32)
+        expert_map[expert_ids] = torch.arange(
+            local_num_experts, device="cuda", dtype=torch.int32
+        )
     topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda")
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
@@ -980,6 +993,7 @@ def test_moe_align_block_size_opcheck():
             sorted_ids,
             expert_ids,
             num_tokens_post_pad,
+            expert_map,
         ),
     )
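For context on the new argument: the ep_size=2 branch above builds an expert_map that translates global expert ids into rank-local ids (-1 for experts not owned by the rank), and that map is the extra tensor now passed to the moe_align_block_size opcheck. Below is a minimal, CPU-only sketch of that mapping that mirrors the tensor construction in the test; tensor names are taken from the diff, and the sketch is illustrative rather than the kernel itself.

    import torch

    num_experts = 4
    ep_size = 2
    local_num_experts = num_experts // ep_size  # experts owned by this rank

    # Sample which global experts live on this rank (the test also picks them
    # randomly, so duplicates are possible there as well).
    expert_ids = torch.randint(0, num_experts, (local_num_experts,), dtype=torch.int32)

    # Global expert id -> local expert id; -1 marks experts not local to this rank.
    expert_map = torch.full((num_experts,), -1, dtype=torch.int32)
    expert_map[expert_ids] = torch.arange(local_num_experts, dtype=torch.int32)

    print(expert_map)  # e.g. tensor([-1,  0, -1,  1], dtype=torch.int32)

Presumably the kernel uses this map to remap and skip non-local experts on device under expert parallelism, which is what the parametrized opcheck now exercises.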