permute/unpermute kernel for moe optimization (#14568)

Signed-off-by: Caleb_Du <Caleb_Du@zju.edu.cn>
This commit is contained in:
Caleb_Du
2025-05-03 02:31:55 +08:00
committed by GitHub
parent 0f87d8f7b2
commit 3e887d2e0c
19 changed files with 1474 additions and 28 deletions

View File

@@ -175,10 +175,8 @@ class ArcticMoE(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
do_normalize = self.top_k > 1
topk_weights, topk_ids = fused_topk(hidden_states,
router_logits,
self.top_k,
renormalize=do_normalize)
topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states, router_logits, self.top_k, renormalize=do_normalize)
# topk_ids: (num_tokens, k)
if self.is_quant:
if 2 * num_tokens <= self.num_experts: