[Kernel][MoE] optimize moe_align_block_size (#29642)
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
@@ -316,7 +316,11 @@ def fused_marlin_moe(
     if global_num_experts == -1:
         global_num_experts = E
     sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, block_size_m, global_num_experts, expert_map
+        topk_ids,
+        block_size_m,
+        global_num_experts,
+        expert_map,
+        ignore_invalid_experts=True,
     )

     assert activation is not None
@@ -1887,7 +1887,11 @@ def fused_experts_impl(
         )

         sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-            curr_topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map
+            curr_topk_ids,
+            config["BLOCK_SIZE_M"],
+            global_num_experts,
+            expert_map,
+            ignore_invalid_experts=True,
         )

         invoke_fused_moe_kernel(
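Note: both call sites (fused_marlin_moe and fused_experts_impl) opt into the new behavior. Under expert parallelism, expert_map sends experts owned by other ranks to -1, so with ignore_invalid_experts=True those entries are dropped during the alignment pass instead of being scheduled and then masked out downstream.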
@@ -1946,6 +1950,9 @@ def fused_experts_impl(
             block_shape=block_shape,
         )

+        if expert_map is not None:
+            intermediate_cache3.zero_()
+
         invoke_fused_moe_kernel(
             qintermediate_cache2,
             w2,
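Note: a likely reason for the new zero-fill (my reading, not stated in the diff): once invalid experts are ignored, rows of intermediate_cache3 belonging to dropped token-expert pairs are never written by the down-projection kernel, so the buffer is cleared up front to keep stale values out of the final top-k weighted sum.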
@@ -14,6 +14,7 @@ def moe_align_block_size(
     num_experts: int,
     expert_map: torch.Tensor | None = None,
     pad_sorted_ids: bool = False,
+    ignore_invalid_experts: bool = False,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Aligns the token distribution across experts to be compatible with block
@@ -35,7 +36,13 @@ def moe_align_block_size(
       expert parallel shard. If the expert is not in the current expert
       parallel shard, the mapping is set to -1.
     - pad_sorted_ids: A flag indicating whether the sorted_token_ids length
       should be padded to a multiple of block_size,
+    - ignore_invalid_experts: A flag indicating whether to ignore invalid
+      experts. When False, all expert_ids in topk_ids will participate in
+      counting and ranking, but invalid experts in expert_ids will be marked
+      as -1. When True, all invalid expert_ids in topk_ids will be ignored
+      and will not participate in counting or ranking, and there will be no
+      -1 in expert_ids.

     Returns:
     - sorted_token_ids: A tensor containing the sorted token indices according
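A minimal sketch of the two modes, using hypothetical toy tensors (the values below are illustrative, not output of the kernel):

    import torch

    # Hypothetical EP setup: 4 global experts; this rank owns experts 2 and 3,
    # which map to local ids 0 and 1. Foreign experts map to -1.
    expert_map = torch.tensor([-1, -1, 0, 1], dtype=torch.int32)

    # Suppose the per-block expert owners computed from topk_ids were:
    expert_ids = torch.tensor([1, 2, 3, 0])

    # ignore_invalid_experts=False: every expert in topk_ids is counted and
    # ranked, and the Python-side remap marks blocks of foreign experts with
    # -1, which downstream kernels must detect and skip.
    print(expert_map[expert_ids])  # tensor([-1,  0,  1, -1], dtype=torch.int32)

    # ignore_invalid_experts=True: experts 0 and 1 never enter the counting
    # and ranking pass, so expert_ids holds only local ids and no -1 sentinels.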
@@ -67,6 +74,10 @@ def moe_align_block_size(
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
     if pad_sorted_ids:
         max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
+    if topk_ids.numel() < num_experts:
+        max_num_tokens_padded = min(
+            topk_ids.numel() * block_size, max_num_tokens_padded
+        )
     sorted_ids = torch.empty(
         (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
     )
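The tighter bound follows from a counting argument: at most topk_ids.numel() experts can receive any token, and each hit expert costs its token count plus at most block_size - 1 padding slots, so numel * block_size slots always suffice. A worked example with made-up decode-time shapes:

    # Hypothetical shapes: a decode step routing 4 token-expert pairs through
    # a model with 256 experts and BLOCK_SIZE_M = 16.
    numel, num_experts, block_size = 4, 256, 16

    old_bound = numel + num_experts * (block_size - 1)  # 4 + 256 * 15 = 3844
    new_bound = min(numel * block_size, old_bound)      # min(64, 3844) = 64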
@@ -77,9 +88,16 @@ def moe_align_block_size(
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)

     ops.moe_align_block_size(
-        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
+        topk_ids,
+        num_experts,
+        block_size,
+        sorted_ids,
+        expert_ids,
+        num_tokens_post_pad,
+        expert_map if ignore_invalid_experts else None,
     )
-    if expert_map is not None:
+
+    if expert_map is not None and not ignore_invalid_experts:
         expert_ids = expert_map[expert_ids]

     return sorted_ids, expert_ids, num_tokens_post_pad
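Note: the extra trailing argument implies the underlying ops.moe_align_block_size op now accepts an optional expert_map (the matching C++/CUDA signature change is presumably elsewhere in this commit), filtering invalid experts inside the kernel; the Python-side gather expert_ids = expert_map[expert_ids] remains only for callers that do not opt in.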