[Kernel] Support bias type in grouped_topk kernel (#31781)

Signed-off-by: Xin Yang <xyangx@amazon.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
Xin Yang
2026-01-07 12:16:32 -08:00
committed by GitHub
parent c907d22158
commit 0ada960a20
3 changed files with 104 additions and 72 deletions

View File

@@ -1627,7 +1627,7 @@ def fused_grouped_topk(
topk,
renormalize,
routed_scaling_factor,
e_score_correction_bias.to(gating_output.dtype),
e_score_correction_bias,
1, # scoring_func=1 for sigmoid
)
elif scoring_func == "softmax":
@@ -1641,7 +1641,7 @@ def fused_grouped_topk(
topk,
renormalize,
routed_scaling_factor,
e_score_correction_bias.to(gating_output.dtype),
e_score_correction_bias,
0, # scoring_func=0 (no activation, scores already computed)
)
else: