[Kernel] Support bias type in grouped_topk kernel (#31781)
Signed-off-by: Xin Yang <xyangx@amazon.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -1627,7 +1627,7 @@ def fused_grouped_topk(
|
||||
topk,
|
||||
renormalize,
|
||||
routed_scaling_factor,
|
||||
e_score_correction_bias.to(gating_output.dtype),
|
||||
e_score_correction_bias,
|
||||
1, # scoring_func=1 for sigmoid
|
||||
)
|
||||
elif scoring_func == "softmax":
|
||||
@@ -1641,7 +1641,7 @@ def fused_grouped_topk(
|
||||
topk,
|
||||
renormalize,
|
||||
routed_scaling_factor,
|
||||
e_score_correction_bias.to(gating_output.dtype),
|
||||
e_score_correction_bias,
|
||||
0, # scoring_func=0 (no activation, scores already computed)
|
||||
)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user