[Kernel][Performance] Fuse float cast and renormalize to topk softmax kernel (#26717)

Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Signed-off-by: izhuhaoran <izhuhaoran@qq.com>
This commit is contained in:
zhrrr
2025-10-17 15:30:35 +08:00
committed by GitHub
parent 5550ff9c25
commit 75c7ad9918
5 changed files with 221 additions and 94 deletions

View File

@@ -5,7 +5,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Apply topk softmax to the gating outputs.
m.def(
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output) -> ()");
"token_expert_indices, Tensor gating_output, bool renormalize) -> ()");
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
// Calculate the result of moe by summing up the partial results