[Kernel] Add fused grouped_topk kernel for MoE (#23274)

Signed-off-by: Xin Yang <xyangx@amazon.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
Xin Yang
2025-08-25 11:47:52 -07:00
committed by GitHub
parent 2a167b2eeb
commit 8a3cd90af5
8 changed files with 909 additions and 2 deletions

View File

@@ -949,8 +949,23 @@ def grouped_topk(
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK and \
current_platform.is_cuda() and \
num_expert_group <= 32 and topk <= 32 and \
e_score_correction_bias is not None:
return fused_grouped_topk(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
e_score_correction_bias=e_score_correction_bias,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor)
assert hidden_states.size(0) == gating_output.size(0), (
"Number of tokens mismatch")
@@ -996,9 +1011,38 @@ def grouped_topk(
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights * routed_scaling_factor
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
def fused_grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
e_score_correction_bias: torch.Tensor,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.size(0) == gating_output.size(0), (
"Number of tokens mismatch")
if scoring_func == "softmax":
scores = torch.softmax(gating_output, dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
topk_values, topk_indices = ops.grouped_topk(
scores, scores_with_bias.to(scores.dtype), num_expert_group,
topk_group, topk, renormalize, routed_scaling_factor)
return topk_values.to(torch.float32), topk_indices.to(torch.int32)
def get_config_dtype_str(
dtype: torch.dtype,
use_int4_w4a16: Optional[bool] = False,