[Kernel] Add fused grouped_topk kernel for MoE (#23274)
Signed-off-by: Xin Yang <xyangx@amazon.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
@@ -949,8 +949,23 @@ def grouped_topk(
|
||||
num_expert_group: int = 0,
|
||||
topk_group: int = 0,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK and \
|
||||
current_platform.is_cuda() and \
|
||||
num_expert_group <= 32 and topk <= 32 and \
|
||||
e_score_correction_bias is not None:
|
||||
return fused_grouped_topk(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=gating_output,
|
||||
topk=topk,
|
||||
renormalize=renormalize,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor)
|
||||
|
||||
assert hidden_states.size(0) == gating_output.size(0), (
|
||||
"Number of tokens mismatch")
|
||||
@@ -996,9 +1011,38 @@ def grouped_topk(
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
topk_weights = topk_weights * routed_scaling_factor
|
||||
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
||||
|
||||
|
||||
def fused_grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    e_score_correction_bias: torch.Tensor,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Group-limited top-k expert routing via the fused CUDA kernel.

    Activates the router logits with ``scoring_func``, adds the per-expert
    score-correction bias, and hands both the unbiased and the biased scores
    to the ``ops.grouped_topk`` kernel, which performs the grouped top-k
    selection (presumably selecting on the biased scores — behavior lives in
    the kernel, not here).

    Args:
        hidden_states: Input activations; only its token count (size of
            dim 0) is validated against ``gating_output`` here.
        gating_output: Router logits, one row per token.
        topk: Number of experts to select per token.
        renormalize: Forwarded to the kernel; whether selected weights are
            renormalized there.
        e_score_correction_bias: Per-expert bias added (broadcast over
            tokens) to the activated scores.
        num_expert_group: Number of expert groups; forwarded to the kernel.
        topk_group: Number of groups to consider; forwarded to the kernel.
        scoring_func: Either ``"softmax"`` or ``"sigmoid"``.
        routed_scaling_factor: Scaling factor forwarded to the kernel.

    Returns:
        Tuple of ``(topk_weights, topk_ids)`` cast to float32 and int32.

    Raises:
        ValueError: If ``scoring_func`` is not one of the supported names.
    """
    assert hidden_states.size(0) == gating_output.size(0), (
        "Number of tokens mismatch")

    if scoring_func == "softmax":
        router_probs = torch.softmax(gating_output, dim=-1)
    elif scoring_func == "sigmoid":
        router_probs = gating_output.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    # Bias broadcasts over the token dimension; cast back to the activated
    # scores' dtype so the kernel receives both inputs in one dtype.
    biased_probs = router_probs + e_score_correction_bias.unsqueeze(0)
    biased_probs = biased_probs.to(router_probs.dtype)

    weights, ids = ops.grouped_topk(router_probs, biased_probs,
                                    num_expert_group, topk_group, topk,
                                    renormalize, routed_scaling_factor)
    return weights.to(torch.float32), ids.to(torch.int32)
|
||||
|
||||
|
||||
def get_config_dtype_str(
|
||||
dtype: torch.dtype,
|
||||
use_int4_w4a16: Optional[bool] = False,
|
||||
|
||||
Reference in New Issue
Block a user