[Kernel] Add topk_sigmoid kernel (#31246)

Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
Xin Yang
2026-01-21 14:49:51 -08:00
committed by GitHub
parent e675dda67b
commit 63227accf5
13 changed files with 725 additions and 126 deletions

View File

@@ -200,6 +200,24 @@ def _rocm_aiter_topk_softmax_fake(
pass
def _rocm_aiter_topk_sigmoid_impl(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
gating_output: torch.Tensor,
) -> None:
from aiter import topk_sigmoid
topk_sigmoid(topk_weights, topk_indices, gating_output)
def _rocm_aiter_topk_sigmoid_fake(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
gating_output: torch.Tensor,
) -> None:
pass
def _rocm_aiter_biased_grouped_topk_impl(
gating_output: torch.Tensor,
correction_bias: torch.Tensor,
@@ -985,6 +1003,14 @@ class rocm_aiter_ops:
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_topk_sigmoid",
op_func=_rocm_aiter_topk_sigmoid_impl,
mutates_args=["topk_weights", "topk_indices"],
fake_impl=_rocm_aiter_topk_sigmoid_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_biased_grouped_topk",
op_func=_rocm_aiter_biased_grouped_topk_impl,
@@ -1272,6 +1298,19 @@ class rocm_aiter_ops:
)
return topk_weights, topk_indices
@staticmethod
def topk_sigmoid(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
) -> tuple[torch.Tensor, ...]:
torch.ops.vllm.rocm_aiter_topk_sigmoid(
topk_weights, topk_indices, gating_output
)
return topk_weights, topk_indices
@staticmethod
def biased_grouped_topk(
gating_output: torch.Tensor,

View File

@@ -2177,9 +2177,33 @@ def topk_softmax(
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool = False,
e_score_correction_bias: torch.Tensor | None = None,
) -> None:
torch.ops._moe_C.topk_softmax(
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
topk_weights,
topk_ids,
token_expert_indices,
gating_output,
renormalize,
e_score_correction_bias,
)
def topk_sigmoid(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool = False,
e_score_correction_bias: torch.Tensor | None = None,
) -> None:
torch.ops._moe_C.topk_sigmoid(
topk_weights,
topk_ids,
token_expert_indices,
gating_output,
renormalize,
e_score_correction_bias,
)

View File

@@ -106,14 +106,14 @@ def _quant_flags_to_group_shape(
class RoutingMethodType(IntEnum):
# Default: Softmax -> TopK
Default = (0,)
# Renormalize: TopK -> Softmax
# Renormalize: TopK -> Softmax/Sigmoid
Renormalize = (1,)
# DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups
# -> Top8 experts from the Top4 groups
DeepSeekV3 = (2,)
# Llama4: Top1 -> Sigmoid
Llama4 = (3,)
# RenormalizeNaive: Softmax -> TopK -> Renormalize
# RenormalizeNaive: Softmax/Sigmoid -> TopK -> Renormalize
RenormalizeNaive = (4,)
# TopK: TopK (no softmax)
TopK = (5,)

View File

@@ -4,6 +4,8 @@ from collections.abc import Callable
import torch
import vllm._custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
@@ -12,15 +14,106 @@ from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
def vllm_topk_softmax(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool = False,
e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, ...]:
ops.topk_softmax(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
renormalize,
e_score_correction_bias,
)
return topk_weights, topk_indices
def vllm_topk_sigmoid(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool = False,
e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, ...]:
ops.topk_sigmoid(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
renormalize,
e_score_correction_bias,
)
return topk_weights, topk_indices
def fused_topk_bias(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
e_score_correction_bias: torch.Tensor,
topk: int,
renormalize: bool,
scoring_func: str = "softmax",
indices_type: torch.dtype | None = None,
):
if not rocm_aiter_ops.is_fused_moe_enabled():
assert hidden_states.size(0) == gating_output.size(0), (
"Number of tokens mismatch"
)
M, _ = hidden_states.size()
topk_weights = torch.empty(
M, topk, dtype=torch.float32, device=hidden_states.device
)
topk_ids = torch.empty(
M,
topk,
dtype=torch.int32 if indices_type is None else indices_type,
device=hidden_states.device,
)
token_expert_indices = torch.empty(
M, topk, dtype=torch.int32, device=hidden_states.device
)
if scoring_func == "softmax":
topk_weights, topk_ids = vllm_topk_softmax(
topk_weights,
topk_ids,
token_expert_indices,
gating_output,
renormalize,
e_score_correction_bias,
)
return topk_weights, topk_ids
elif scoring_func == "sigmoid":
topk_weights, topk_ids = vllm_topk_sigmoid(
topk_weights,
topk_ids,
token_expert_indices,
gating_output,
renormalize,
e_score_correction_bias,
)
return topk_weights, topk_ids
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
n_routed_experts = gating_output.shape[-1]
scores = gating_output.softmax(dim=-1)
if scoring_func == "softmax":
scores = gating_output.softmax(dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
scores_for_choice = scores.view(
-1, n_routed_experts
) + e_score_correction_bias.unsqueeze(0)
@@ -43,6 +136,7 @@ class FusedTopKBiasRouter(BaseRouter):
global_num_experts: int,
eplb_state: EplbLayerState,
e_score_correction_bias: torch.Tensor,
scoring_func: str,
renormalize: bool = True,
routed_scaling_factor: float = 1.0,
enable_eplb: bool = False,
@@ -57,6 +151,7 @@ class FusedTopKBiasRouter(BaseRouter):
)
self.e_score_correction_bias = e_score_correction_bias
self.renormalize = renormalize
self.scoring_func = scoring_func
self.routed_scaling_factor = routed_scaling_factor
@property
@@ -80,6 +175,7 @@ class FusedTopKBiasRouter(BaseRouter):
e_score_correction_bias=self.e_score_correction_bias.data,
topk=self.top_k,
renormalize=self.renormalize,
scoring_func=self.scoring_func,
)
if self.routed_scaling_factor != 1.0:

View File

@@ -16,7 +16,7 @@ def vllm_topk_softmax(
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
renormalize: bool = False,
) -> tuple[torch.Tensor, ...]:
ops.topk_softmax(
topk_weights,
@@ -29,7 +29,25 @@ def vllm_topk_softmax(
return topk_weights, topk_indices
def dispatch_topk_func(
def vllm_topk_sigmoid(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool = False,
) -> tuple[torch.Tensor, ...]:
ops.topk_sigmoid(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
renormalize,
)
return topk_weights, topk_indices
def dispatch_topk_softmax_func(
use_rocm_aiter: bool = False,
) -> Callable[..., tuple[torch.Tensor, ...]]:
if use_rocm_aiter:
@@ -37,12 +55,21 @@ def dispatch_topk_func(
return vllm_topk_softmax
def dispatch_topk_sigmoid_func(
use_rocm_aiter: bool = False,
) -> Callable[..., tuple[torch.Tensor, ...]]:
if use_rocm_aiter:
return rocm_aiter_ops.topk_sigmoid
return vllm_topk_sigmoid
def fused_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
indices_type: torch.dtype | None = None,
scoring_func: str = "softmax",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"
@@ -61,12 +88,26 @@ def fused_topk(
M, topk, dtype=torch.int32, device=hidden_states.device
)
topk_func = dispatch_topk_func(use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled())
topk_weights, topk_ids = topk_func(
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
)
if scoring_func == "softmax":
topk_func = dispatch_topk_softmax_func(
use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()
)
topk_weights, topk_ids = topk_func(
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
)
return topk_weights, topk_ids, token_expert_indices
return topk_weights, topk_ids, token_expert_indices
elif scoring_func == "sigmoid":
topk_func = dispatch_topk_sigmoid_func(
use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()
)
topk_weights, topk_ids = topk_func(
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
)
return topk_weights, topk_ids, token_expert_indices
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
class FusedTopKRouter(BaseRouter):
@@ -82,7 +123,6 @@ class FusedTopKRouter(BaseRouter):
enable_eplb: bool = False,
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
):
assert scoring_func == "softmax", "FusedTopKRouter only supports softmax."
super().__init__(
top_k=top_k,
global_num_experts=global_num_experts,
@@ -91,6 +131,7 @@ class FusedTopKRouter(BaseRouter):
indices_type_getter=indices_type_getter,
)
self.renormalize = renormalize
self.scoring_func = scoring_func
@property
def routing_method_type(self) -> RoutingMethodType:
@@ -113,6 +154,7 @@ class FusedTopKRouter(BaseRouter):
topk=self.top_k,
renormalize=self.renormalize,
indices_type=indices_type,
scoring_func=self.scoring_func,
)
return topk_weights, topk_ids

View File

@@ -143,17 +143,13 @@ def create_fused_moe_router(
router.capture = capture
return router
if scoring_func != "softmax":
raise ValueError(
"Only softmax scoring function is supported for non-grouped topk."
)
if e_score_correction_bias is not None:
router = FusedTopKBiasRouter(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
e_score_correction_bias=e_score_correction_bias,
scoring_func=scoring_func,
renormalize=renormalize,
routed_scaling_factor=routed_scaling_factor,
enable_eplb=enable_eplb,

View File

@@ -100,9 +100,6 @@ class MiniMaxM2MoE(nn.Module):
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
scoring_func=config.scoring_func,
use_grouped_topk=True,
num_expert_group=1,
topk_group=1,
e_score_correction_bias=self.e_score_correction_bias,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,