permute/unpermute kernel for moe optimization (#14568)

Signed-off-by: Caleb_Du <Caleb_Du@zju.edu.cn>
This commit is contained in:
Caleb_Du
2025-05-03 02:31:55 +08:00
committed by GitHub
parent 0f87d8f7b2
commit 3e887d2e0c
19 changed files with 1474 additions and 28 deletions

View File

@@ -854,7 +854,7 @@ def fused_topk(
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
@@ -868,20 +868,19 @@ def fused_topk(
topk,
dtype=torch.int32,
device=hidden_states.device)
token_expert_indicies = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
token_expert_indices = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
gating_output_float = gating_output.float() # TODO(woosuk): Optimize this.
topk_func = dispatch_topk_func()
topk_weights, topk_ids = topk_func(topk_weights, topk_ids,
token_expert_indicies,
token_expert_indices,
gating_output_float, renormalize)
del token_expert_indicies # Not used. Will be used in the future.
return topk_weights, topk_ids
return topk_weights, topk_ids, token_expert_indices
# This is used by the Deepseek-V2 and Deepseek-V3 model
@@ -1510,8 +1509,8 @@ def fused_moe(
topk, renormalize,
num_expert_group, topk_group)
elif custom_routing_function is None:
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)
topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states, gating_output, topk, renormalize)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states, gating_output, topk, renormalize)