[Perf][DeepSeek] Add sigmoid+bias fusion to fused_grouped_topk from TRTLLM (#28124)
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
@@ -1330,24 +1330,37 @@ def fused_grouped_topk(
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"
|
||||
|
||||
if scoring_func == "softmax":
|
||||
if scoring_func == "sigmoid":
|
||||
# Fully fused kernel path for sigmoid
|
||||
topk_values, topk_indices = ops.grouped_topk(
|
||||
gating_output, # raw logits
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
topk,
|
||||
renormalize,
|
||||
routed_scaling_factor,
|
||||
e_score_correction_bias.to(gating_output.dtype),
|
||||
1, # scoring_func=1 for sigmoid
|
||||
)
|
||||
elif scoring_func == "softmax":
|
||||
# Apply softmax in Python, then use fused kernel
|
||||
# TODO: Add support for softmax in kernel
|
||||
scores = torch.softmax(gating_output, dim=-1)
|
||||
elif scoring_func == "sigmoid":
|
||||
scores = gating_output.sigmoid()
|
||||
topk_values, topk_indices = ops.grouped_topk(
|
||||
scores, # pre-computed scores
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
topk,
|
||||
renormalize,
|
||||
routed_scaling_factor,
|
||||
e_score_correction_bias.to(gating_output.dtype),
|
||||
0, # scoring_func=0 (no activation, scores already computed)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported scoring function: {scoring_func}")
|
||||
|
||||
scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
|
||||
topk_values, topk_indices = ops.grouped_topk(
|
||||
scores,
|
||||
scores_with_bias.to(scores.dtype),
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
topk,
|
||||
renormalize,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
return topk_values.to(torch.float32), topk_indices.to(torch.int32)
|
||||
# Fused kernel outputs float32 values and int32 indices directly
|
||||
return topk_values, topk_indices
|
||||
|
||||
|
||||
def inplace_fused_experts(
|
||||
|
||||
Reference in New Issue
Block a user