[Kernel][MoE] optimize moe_align_block_size (#29642)
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
@@ -955,9 +955,22 @@ def test_fused_marlin_moe_with_bias(m):
     torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
 
 
-def test_moe_align_block_size_opcheck():
+@pytest.mark.parametrize("ep_size", [1, 2])
+def test_moe_align_block_size_opcheck(ep_size):
     num_experts = 4
     block_size = 4
+
+    expert_map = None
+    if ep_size != 1:
+        local_num_experts = num_experts // ep_size
+        expert_ids = torch.randint(
+            0, num_experts, (local_num_experts,), device="cuda", dtype=torch.int32
+        )
+        expert_map = torch.full((num_experts,), -1, device="cuda", dtype=torch.int32)
+        expert_map[expert_ids] = torch.arange(
+            local_num_experts, device="cuda", dtype=torch.int32
+        )
+
     topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda")
 
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
@@ -980,6 +993,7 @@ def test_moe_align_block_size_opcheck():
             sorted_ids,
             expert_ids,
             num_tokens_post_pad,
+            expert_map,
         ),
     )
 
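A note on the expert_map setup above, since it is the interesting part of the new test: under expert parallelism (EP) each rank owns num_experts // ep_size experts, and expert_map translates global expert ids into contiguous local ids, with -1 marking experts that live on another rank. The max_num_tokens_padded bound also follows directly from the alignment: each expert's token count is rounded up to a multiple of block_size, so at worst every expert adds block_size - 1 padding slots on top of the topk_ids.numel() real entries. Below is a minimal CPU-side sketch of the mapping, with fixed expert ids instead of the test's torch.randint draw so it is deterministic and runs without a GPU:

import torch

num_experts = 4                              # global expert count, as in the test
ep_size = 2                                  # expert-parallel world size
local_num_experts = num_experts // ep_size   # experts owned by this rank

# Suppose this rank owns global experts 1 and 3. (The test draws them with
# torch.randint, so duplicates are possible there; a duplicate simply leaves
# fewer global ids mapped.)
owned = torch.tensor([1, 3])

# -1 marks remote experts; owned experts get contiguous local ids
# 0 .. local_num_experts - 1.
expert_map = torch.full((num_experts,), -1, dtype=torch.int32)
expert_map[owned] = torch.arange(local_num_experts, dtype=torch.int32)

print(expert_map)  # tensor([-1,  0, -1,  1], dtype=torch.int32)

Passing this map as the new trailing argument lets moe_align_block_size resolve local expert ids (and skip tokens routed to remote experts) inside the kernel itself, which is presumably where the speedup in this commit comes from.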
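For readers who want the semantics the opcheck is validating, here is a rough, unoptimized Python reference of what moe_align_block_size computes, as I read the kernel's contract: tokens are grouped by (local) expert, each group is padded up to a multiple of block_size, and each resulting block is tagged with its expert id. This is an illustrative sketch, not vLLM code; details such as the padding sentinel (assumed here to be topk_ids.numel()) should be checked against the CUDA kernel.

import torch

def ref_moe_align_block_size(topk_ids, num_experts, block_size, expert_map=None):
    # Hypothetical reference implementation, for illustration only.
    flat = topk_ids.flatten().tolist()
    sentinel = len(flat)  # assumed fill value for unused (padding) slots
    n_local = num_experts if expert_map is None else int((expert_map >= 0).sum())
    buckets = [[] for _ in range(n_local)]
    for pos, e in enumerate(flat):
        if expert_map is not None:
            e = int(expert_map[e])
            if e == -1:      # token routed to a remote expert: drop it
                continue
        buckets[e].append(pos)
    sorted_ids, expert_ids = [], []
    for e, toks in enumerate(buckets):
        n_blocks = -(-len(toks) // block_size)   # ceil division
        padded = n_blocks * block_size
        sorted_ids += toks + [sentinel] * (padded - len(toks))
        expert_ids += [e] * n_blocks             # one expert id per block
    return sorted_ids, expert_ids, len(sorted_ids)

topk_ids = torch.tensor([[0, 2], [1, 2], [3, 0]], dtype=torch.int32)
print(ref_moe_align_block_size(topk_ids, num_experts=4, block_size=4))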