[Kernel][MoE] optimize moe_align_block_size (#29642)
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
@@ -955,9 +955,22 @@ def test_fused_marlin_moe_with_bias(m):
     torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
 
 
-def test_moe_align_block_size_opcheck():
+@pytest.mark.parametrize("ep_size", [1, 2])
+def test_moe_align_block_size_opcheck(ep_size):
     num_experts = 4
     block_size = 4
+
+    expert_map = None
+    if ep_size != 1:
+        local_num_experts = num_experts // ep_size
+        expert_ids = torch.randint(
+            0, num_experts, (local_num_experts,), device="cuda", dtype=torch.int32
+        )
+        expert_map = torch.full((num_experts,), -1, device="cuda", dtype=torch.int32)
+        expert_map[expert_ids] = torch.arange(
+            local_num_experts, device="cuda", dtype=torch.int32
+        )
+
     topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda")
 
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
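The new ep_size parametrization exercises the expert-parallel path: only num_experts // ep_size experts are local to the rank, and expert_map translates global expert ids to local ones, with -1 marking experts owned by other ranks. A minimal CPU sketch of that mapping, standalone and with hypothetical ids (the test draws them with torch.randint):

    import torch

    # Hypothetical: 4 global experts, of which ids 1 and 3 are local to this rank.
    num_experts = 4
    local_expert_ids = torch.tensor([1, 3], dtype=torch.int32)
    expert_map = torch.full((num_experts,), -1, dtype=torch.int32)
    expert_map[local_expert_ids] = torch.arange(2, dtype=torch.int32)
    # expert_map == tensor([-1, 0, -1, 1]): global id 1 -> local 0, global id 3 -> local 1.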
@@ -980,6 +993,7 @@ def test_moe_align_block_size_opcheck():
         sorted_ids,
         expert_ids,
         num_tokens_post_pad,
+        expert_map,
     ),
 )
 
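The opcheck call now also passes the expert_map buffer through to the custom op. torch.library.opcheck runs PyTorch's operator test suite (schema, fake-tensor, and autograd registration checks) against the op with the given arguments. A self-contained sketch on a hypothetical demo op, not the vLLM binding:

    import torch
    from torch.library import opcheck

    # Hypothetical custom op used only to illustrate what opcheck verifies.
    @torch.library.custom_op("demo::scale", mutates_args=())
    def scale(x: torch.Tensor, s: float) -> torch.Tensor:
        return x * s

    @scale.register_fake
    def _(x: torch.Tensor, s: float) -> torch.Tensor:
        return torch.empty_like(x)

    opcheck(scale, (torch.randn(4), 2.0))  # schema / fake-tensor / autograd checks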
@@ -106,6 +106,8 @@ def torch_moe_align_block_size(
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
     if pad_sorted_ids:
         max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
+    if topk_ids.numel() < num_experts:
+        max_num_tokens_padded = topk_ids.numel() * block_size
 
     flattened_token_indices = torch.arange(
         topk_ids.numel(), device=topk_ids.device, dtype=torch.int32
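This hunk tightens the padding bound in the pure-PyTorch reference used as the golden implementation in the tests. Each active expert wastes at most block_size - 1 slots in its last, partially filled block, and when there are fewer token-expert pairs than experts, at most topk_ids.numel() experts can be active, so the worst case is every pair sitting alone in its own block. Checking the arithmetic with hypothetical numbers:

    # Hypothetical sizes: 12 token-expert pairs, 64 experts, block_size 16.
    numel, num_experts, block_size = 12, 64, 16
    loose = numel + num_experts * (block_size - 1)  # 12 + 64 * 15 = 972
    tight = numel * block_size                      # each pair in its own block: 192
    assert numel < num_experts and tight <= loose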
@@ -126,6 +128,8 @@ def torch_moe_align_block_size(
     )
     for expert_id in range(num_experts):
         original_count = expert_token_counts[expert_id]
+        if expert_map is not None and expert_map[expert_id] == -1:
+            continue
         if original_count > 0:
             expert_padded_counts[expert_id] = (
                 (original_count + block_size - 1) // block_size
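Experts that expert_map marks as non-local are now skipped before any padding is counted, so they contribute zero blocks. For the experts that remain, the count is rounded up to whole blocks with the usual ceil-division idiom (the hunk cuts off mid-expression; presumably a trailing * block_size sits on the unchanged line below). For illustration:

    # Hypothetical: round a per-expert token count up to whole blocks.
    original_count, block_size = 10, 4
    padded_count = (original_count + block_size - 1) // block_size * block_size  # 12
    blocks_needed = padded_count // block_size                                   # 3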
@@ -143,6 +147,9 @@ def torch_moe_align_block_size(
     current_pos = 0
     current_block = 0
     for expert_id in range(num_experts):
+        if expert_map is not None and expert_map[expert_id] == -1:
+            continue
+
         expert_mask = sorted_expert_ids == expert_id
         expert_tokens = sorted_token_indices[expert_mask]
         num_expert_tokens = expert_tokens.shape[0]
@@ -153,7 +160,13 @@ def torch_moe_align_block_size(
         )
 
         expert_blocks_needed = expert_padded_counts[expert_id] // block_size
-        expert_ids[current_block : current_block + expert_blocks_needed] = expert_id
+
+        expert_id_new = expert_id
+        if expert_map is not None:
+            expert_id_new = expert_map[expert_id]
+        expert_ids[current_block : current_block + expert_blocks_needed] = (
+            expert_id_new
+        )
 
         current_pos += expert_padded_counts[expert_id]
         current_block += expert_blocks_needed
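The per-block expert table is now filled with local expert ids at placement time: under expert parallelism, downstream kernels index the rank's weight tensors by local id, so each global id is translated through expert_map as its blocks are written. Continuing the hypothetical mapping from above:

    # expert_map == tensor([-1, 0, -1, 1]); global expert 1 owns two blocks and
    # global expert 3 owns one, so the block table holds local ids:
    expert_ids_blocks = torch.tensor([0, 0, 1], dtype=torch.int32)  # not [1, 1, 3]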
@@ -163,8 +176,6 @@ def torch_moe_align_block_size(
         [total_padded_tokens], dtype=torch.int32, device=topk_ids.device
     )
 
-    if expert_map is not None:
-        expert_ids = expert_map[expert_ids]
     return sorted_token_ids, expert_ids, num_tokens_post_pad
 
 
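The old post-hoc gather is deleted because the remap now happens inside the placement loop; keeping both would translate already-local ids a second time. A small demonstration of why the double gather corrupts the table, reusing the hypothetical expert_map:

    expert_map = torch.tensor([-1, 0, -1, 1], dtype=torch.int32)
    local_ids = torch.tensor([0, 0, 1], dtype=torch.int32)  # already local
    # expert_map[local_ids] -> tensor([-1, -1, 0]): a second remap is wrong.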
@@ -229,9 +240,9 @@ def test_moe_align_block_size(
     )
 
 
-@pytest.mark.parametrize("m", [16, 32])
+@pytest.mark.parametrize("m", [16, 32, 2048])
 @pytest.mark.parametrize("topk", [2, 4])
-@pytest.mark.parametrize("num_experts", [8])
+@pytest.mark.parametrize("num_experts", [8, 64])
 @pytest.mark.parametrize("block_size", [64])
 @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
 def test_moe_align_block_size_with_expert_map(
@@ -253,6 +264,7 @@ def test_moe_align_block_size_with_expert_map(
         block_size=block_size,
         num_experts=num_experts,
         expert_map=expert_map,
+        ignore_invalid_experts=True,
     )
     golden_sorted_ids, golden_expert_ids, golden_num_tokens = (
         torch_moe_align_block_size(