[Kernel][Performance] Fuse float cast and renormalize to topk softmax kernel (#26717)

Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Signed-off-by: izhuhaoran <izhuhaoran@qq.com>
This commit is contained in:
zhrrr
2025-10-17 15:30:35 +08:00
committed by GitHub
parent 5550ff9c25
commit 75c7ad9918
5 changed files with 221 additions and 94 deletions

View File

@@ -1074,9 +1074,8 @@ def vllm_topk_softmax(
 topk_indices,
 token_expert_indices,
 gating_output,
+renormalize,
 )
-if renormalize:
-topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
 return topk_weights, topk_indices
@@ -1113,11 +1112,9 @@ def fused_topk(
 M, topk, dtype=torch.int32, device=hidden_states.device
 )
-gating_output_float = gating_output.float() # TODO(woosuk): Optimize this.
 topk_func = dispatch_topk_func()
 topk_weights, topk_ids = topk_func(
-topk_weights, topk_ids, token_expert_indices, gating_output_float, renormalize
+topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
 )
 return topk_weights, topk_ids, token_expert_indices