[Perf][DeepSeek] Add sigmoid+bias fusion to fused_grouped_topk from TRTLLM (#28124)

Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
Michael Goin
2025-11-08 10:20:55 +08:00
committed by GitHub
parent 61d25dc44b
commit 0852527647
5 changed files with 149 additions and 75 deletions

View File

@@ -1330,24 +1330,37 @@ def fused_grouped_topk(
) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"
if scoring_func == "softmax":
if scoring_func == "sigmoid":
# Fully fused kernel path for sigmoid
topk_values, topk_indices = ops.grouped_topk(
gating_output, # raw logits
num_expert_group,
topk_group,
topk,
renormalize,
routed_scaling_factor,
e_score_correction_bias.to(gating_output.dtype),
1, # scoring_func=1 for sigmoid
)
elif scoring_func == "softmax":
# Apply softmax in Python, then use fused kernel
# TODO: Add support for softmax in kernel
scores = torch.softmax(gating_output, dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
topk_values, topk_indices = ops.grouped_topk(
scores, # pre-computed scores
num_expert_group,
topk_group,
topk,
renormalize,
routed_scaling_factor,
e_score_correction_bias.to(gating_output.dtype),
0, # scoring_func=0 (no activation, scores already computed)
)
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
topk_values, topk_indices = ops.grouped_topk(
scores,
scores_with_bias.to(scores.dtype),
num_expert_group,
topk_group,
topk,
renormalize,
routed_scaling_factor,
)
return topk_values.to(torch.float32), topk_indices.to(torch.int32)
# Fused kernel outputs float32 values and int32 indices directly
return topk_values, topk_indices
def inplace_fused_experts(