[Kernel][Performance] Fuse float cast and renormalization into the topk softmax kernel (#26717)
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Signed-off-by: izhuhaoran <izhuhaoran@qq.com>
This commit is contained in:
@@ -1074,9 +1074,8 @@ def vllm_topk_softmax(
|
||||
topk_indices,
|
||||
token_expert_indices,
|
||||
gating_output,
|
||||
renormalize,
|
||||
)
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
return topk_weights, topk_indices
|
||||
|
||||
@@ -1113,11 +1112,9 @@ def fused_topk(
|
||||
M, topk, dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
|
||||
gating_output_float = gating_output.float() # TODO(woosuk): Optimize this.
|
||||
|
||||
topk_func = dispatch_topk_func()
|
||||
topk_weights, topk_ids = topk_func(
|
||||
topk_weights, topk_ids, token_expert_indices, gating_output_float, renormalize
|
||||
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
|
||||
)
|
||||
|
||||
return topk_weights, topk_ids, token_expert_indices
|
||||
|
||||
Reference in New Issue
Block a user