Permute/unpermute kernel for MoE optimization (#14568)
Signed-off-by: Caleb_Du <Caleb_Du@zju.edu.cn>
This commit is contained in:
@@ -854,7 +854,7 @@ def fused_topk(
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], (
|
||||
"Number of tokens mismatch")
|
||||
|
||||
@@ -868,20 +868,19 @@ def fused_topk(
|
||||
topk,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
token_expert_indicies = torch.empty(M,
|
||||
topk,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
token_expert_indices = torch.empty(M,
|
||||
topk,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
|
||||
gating_output_float = gating_output.float() # TODO(woosuk): Optimize this.
|
||||
|
||||
topk_func = dispatch_topk_func()
|
||||
topk_weights, topk_ids = topk_func(topk_weights, topk_ids,
|
||||
token_expert_indicies,
|
||||
token_expert_indices,
|
||||
gating_output_float, renormalize)
|
||||
|
||||
del token_expert_indicies # Not used. Will be used in the future.
|
||||
return topk_weights, topk_ids
|
||||
return topk_weights, topk_ids, token_expert_indices
|
||||
|
||||
|
||||
# This is used by the Deepseek-V2 and Deepseek-V3 models
|
||||
@@ -1510,8 +1509,8 @@ def fused_moe(
|
||||
topk, renormalize,
|
||||
num_expert_group, topk_group)
|
||||
elif custom_routing_function is None:
|
||||
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
|
||||
renormalize)
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
hidden_states, gating_output, topk, renormalize)
|
||||
else:
|
||||
topk_weights, topk_ids = custom_routing_function(
|
||||
hidden_states, gating_output, topk, renormalize)
|
||||
|
||||
Reference in New Issue
Block a user