[Kernel][MoE] optimize moe_align_block_size (#29642)

Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Author: Jinzhen Lin
Date: 2025-12-07 17:58:47 +08:00 (committed by GitHub)
Commit: 879ddb09c3 (parent 1b0482b9d1)

10 changed files with 195 additions and 63 deletions

@@ -316,7 +316,11 @@ def fused_marlin_moe(
     if global_num_experts == -1:
         global_num_experts = E
     sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, block_size_m, global_num_experts, expert_map
+        topk_ids,
+        block_size_m,
+        global_num_experts,
+        expert_map,
+        ignore_invalid_experts=True,
     )
     assert activation is not None

@@ -1887,7 +1887,11 @@ def fused_experts_impl(
         )
         sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-            curr_topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map
+            curr_topk_ids,
+            config["BLOCK_SIZE_M"],
+            global_num_experts,
+            expert_map,
+            ignore_invalid_experts=True,
         )
         invoke_fused_moe_kernel(
@@ -1946,6 +1950,9 @@ def fused_experts_impl(
             block_shape=block_shape,
         )
+        if expert_map is not None:
+            intermediate_cache3.zero_()
         invoke_fused_moe_kernel(
             qintermediate_cache2,
             w2,
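Why the new zeroing matters (my reading of the change, not text from the PR): intermediate_cache3 is allocated with torch.empty, and with ignore_invalid_experts=True the alignment pass no longer emits blocks for experts outside the local expert-parallel shard, so the fused kernel never writes the corresponding rows of the cache. Stale memory in those rows would then leak into the final top-k reduction. A minimal sketch of the invariant, with made-up shapes:

    import torch

    num_tokens, topk, hidden = 4, 2, 8  # made-up shapes

    cache = torch.empty(num_tokens, topk, hidden)  # may contain stale garbage
    cache.zero_()  # mirrors the new intermediate_cache3.zero_() call

    # Pretend the fused kernel writes only rows routed to local experts;
    # the row for (token 1, slot 1) is skipped: its expert is non-local.
    written = torch.ones(num_tokens, topk, dtype=torch.bool)
    written[1, 1] = False
    cache[written] = 1.0  # stand-in for real expert outputs

    # The final MoE reduction sums over the top-k slots; with the zero fill,
    # the skipped slot contributes nothing instead of leftover memory.
    out = cache.sum(dim=1)
    assert torch.equal(out[1], torch.ones(hidden))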

@@ -14,6 +14,7 @@ def moe_align_block_size(
     num_experts: int,
     expert_map: torch.Tensor | None = None,
     pad_sorted_ids: bool = False,
+    ignore_invalid_experts: bool = False,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Aligns the token distribution across experts to be compatible with block
@@ -35,7 +36,13 @@ def moe_align_block_size(
       expert parallel shard. If the expert is not in the current expert
       parallel shard, the mapping is set to -1.
     - pad_sorted_ids: A flag indicating whether the sorted_token_ids length
-      should be padded to a multiple of block_size,
+      should be padded to a multiple of block_size.
+    - ignore_invalid_experts: A flag indicating whether to ignore invalid
+      experts. When False, every expert id in topk_ids participates in
+      counting and ranking, and invalid experts are afterwards marked as
+      -1 in expert_ids. When True, invalid expert ids in topk_ids are
+      ignored, taking no part in counting or ranking, so expert_ids
+      contains no -1 entries.
     Returns:
     - sorted_token_ids: A tensor containing the sorted token indices according
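To make the two modes concrete, here is a small host-side illustration of the counting difference (values are invented; the real work happens on the GPU inside ops.moe_align_block_size):

    import torch

    # Hypothetical EP setup: 4 global experts, only 2 and 3 on this rank.
    expert_map = torch.tensor([-1, -1, 0, 1])
    topk_ids = torch.tensor([[2, 0], [3, 2]])  # global expert id per slot

    # ignore_invalid_experts=False: every id is counted and ranked, then the
    # host-side remap expert_map[expert_ids] marks non-local blocks as -1.
    counts_all = torch.bincount(topk_ids.flatten(), minlength=4)
    # -> tensor([1, 0, 2, 1]); non-local expert 0 still claims a padded block.

    # ignore_invalid_experts=True: non-local ids are dropped before counting
    # and ranking, so no block (and no -1 expert id) is ever produced.
    keep = expert_map[topk_ids.flatten()] >= 0
    counts_local = torch.bincount(topk_ids.flatten()[keep], minlength=4)
    # -> tensor([0, 0, 2, 1])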
@@ -67,6 +74,10 @@ def moe_align_block_size(
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
     if pad_sorted_ids:
         max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
+    if topk_ids.numel() < num_experts:
+        max_num_tokens_padded = min(
+            topk_ids.numel() * block_size, max_num_tokens_padded
+        )
     sorted_ids = torch.empty(
         (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
     )
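The new min() tightens the worst-case allocation when there are fewer routed ids than experts, which is common for small batches under expert parallelism: at most topk_ids.numel() buckets can be non-empty, and each non-empty bucket pads to at most block_size slots per id it holds. A worked example with invented sizes:

    # Invented sizes: 2 tokens with top-2 routing (numel = 4), 64 experts,
    # block_size = 16.
    numel, num_experts, block_size = 4, 64, 16

    # Old bound: every expert bucket may need up to block_size - 1 pad slots.
    old_bound = numel + num_experts * (block_size - 1)  # 4 + 64 * 15 = 964

    # New bound: at most numel buckets are non-empty, and padding each to a
    # multiple of block_size needs at most numel * block_size slots in total.
    new_bound = min(numel * block_size, old_bound)  # min(64, 964) = 64

    assert (old_bound, new_bound) == (964, 64)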
@@ -77,9 +88,16 @@ def moe_align_block_size(
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
     ops.moe_align_block_size(
-        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
+        topk_ids,
+        num_experts,
+        block_size,
+        sorted_ids,
+        expert_ids,
+        num_tokens_post_pad,
+        expert_map if ignore_invalid_experts else None,
     )
-    if expert_map is not None:
+    if expert_map is not None and not ignore_invalid_experts:
         expert_ids = expert_map[expert_ids]
     return sorted_ids, expert_ids, num_tokens_post_pad
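Putting it together, a hedged usage sketch of the updated signature; the device, shapes, keyword names, and import path are assumptions based on vLLM's usual layout, not taken from the diff:

    import torch
    from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
        moe_align_block_size,
    )

    # Assumed EP setup: 8 global experts, experts 4..7 local to this rank.
    expert_map = torch.tensor(
        [-1, -1, -1, -1, 0, 1, 2, 3], dtype=torch.int32, device="cuda"
    )
    topk_ids = torch.randint(0, 8, (16, 2), dtype=torch.int32, device="cuda")

    sorted_ids, expert_ids, num_post_pad = moe_align_block_size(
        topk_ids,
        block_size=64,
        num_experts=8,
        expert_map=expert_map,
        ignore_invalid_experts=True,
    )
    # With ignore_invalid_experts=True, expert_map is consumed inside the
    # kernel, the host-side remap is skipped, and expert_ids holds no -1.
    assert (expert_ids >= 0).all()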