[Kernel] Add topk_sigmoid kernel (#31246)
Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
@@ -200,6 +200,24 @@ def _rocm_aiter_topk_softmax_fake(
|
||||
pass
|
||||
|
||||
|
||||
def _rocm_aiter_topk_sigmoid_impl(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    gating_output: torch.Tensor,
) -> None:
    """Run ROCm AITER's fused top-k sigmoid kernel.

    Writes the selected expert weights into ``topk_weights`` and the chosen
    expert ids into ``topk_indices`` in place; ``gating_output`` holds the
    raw router logits. Returns ``None`` (results are in the out-params).
    """
    # Imported lazily so this module loads on platforms without aiter.
    import aiter

    aiter.topk_sigmoid(topk_weights, topk_indices, gating_output)
|
||||
|
||||
|
||||
def _rocm_aiter_topk_sigmoid_fake(
|
||||
topk_weights: torch.Tensor,
|
||||
topk_indices: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def _rocm_aiter_biased_grouped_topk_impl(
|
||||
gating_output: torch.Tensor,
|
||||
correction_bias: torch.Tensor,
|
||||
@@ -985,6 +1003,14 @@ class rocm_aiter_ops:
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_topk_sigmoid",
|
||||
op_func=_rocm_aiter_topk_sigmoid_impl,
|
||||
mutates_args=["topk_weights", "topk_indices"],
|
||||
fake_impl=_rocm_aiter_topk_sigmoid_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_biased_grouped_topk",
|
||||
op_func=_rocm_aiter_biased_grouped_topk_impl,
|
||||
@@ -1272,6 +1298,19 @@ class rocm_aiter_ops:
|
||||
)
|
||||
return topk_weights, topk_indices
|
||||
|
||||
@staticmethod
def topk_sigmoid(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool,
) -> tuple[torch.Tensor, ...]:
    """Top-k sigmoid routing via the registered ROCm AITER custom op.

    Fills ``topk_weights`` and ``topk_indices`` in place from the router
    logits in ``gating_output`` and returns them as a tuple.

    NOTE(review): ``token_expert_indices`` and ``renormalize`` are accepted
    (presumably for signature parity with the topk_softmax dispatch path)
    but are NOT forwarded to the custom op — confirm the AITER kernel
    handles renormalization internally, or that callers on this path never
    pass ``renormalize=True`` expecting it to take effect.
    """
    torch.ops.vllm.rocm_aiter_topk_sigmoid(
        topk_weights, topk_indices, gating_output
    )
    return topk_weights, topk_indices
|
||||
|
||||
@staticmethod
|
||||
def biased_grouped_topk(
|
||||
gating_output: torch.Tensor,
|
||||
|
||||
@@ -2177,9 +2177,33 @@ def topk_softmax(
|
||||
token_expert_indices: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
renormalize: bool = False,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
torch.ops._moe_C.topk_softmax(
|
||||
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
token_expert_indices,
|
||||
gating_output,
|
||||
renormalize,
|
||||
e_score_correction_bias,
|
||||
)
|
||||
|
||||
|
||||
def topk_sigmoid(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
    e_score_correction_bias: torch.Tensor | None = None,
) -> None:
    """Invoke the ``_moe_C`` top-k sigmoid CUDA/HIP kernel.

    Output tensors (``topk_weights``, ``topk_ids``, ``token_expert_indices``)
    are filled in place; nothing is returned. ``renormalize`` and the
    optional ``e_score_correction_bias`` are forwarded directly to the
    kernel — their exact semantics live in the C extension.
    """
    torch.ops._moe_C.topk_sigmoid(
        topk_weights,
        topk_ids,
        token_expert_indices,
        gating_output,
        renormalize,
        e_score_correction_bias,
    )
|
||||
|
||||
|
||||
|
||||
@@ -106,14 +106,14 @@ def _quant_flags_to_group_shape(
|
||||
class RoutingMethodType(IntEnum):
    """Enumeration of MoE routing strategies used to pick experts per token."""

    # Default: Softmax -> TopK
    Default = (0,)
    # Renormalize: TopK -> Softmax/Sigmoid
    Renormalize = (1,)
    # DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups
    # -> Top8 experts from the Top4 groups
    DeepSeekV3 = (2,)
    # Llama4: Top1 -> Sigmoid
    Llama4 = (3,)
    # RenormalizeNaive: Softmax/Sigmoid -> TopK -> Renormalize
    RenormalizeNaive = (4,)
    # TopK: TopK (no softmax)
    TopK = (5,)
|
||||
|
||||
@@ -4,6 +4,8 @@ from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.distributed.eplb.eplb_state import EplbLayerState
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
@@ -12,15 +14,106 @@ from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
|
||||
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
|
||||
|
||||
|
||||
def vllm_topk_softmax(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
    e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, ...]:
    """Top-k softmax routing using vLLM's native kernel.

    Thin wrapper over ``ops.topk_softmax``: the kernel writes results into
    the pre-allocated ``topk_weights`` / ``topk_indices`` /
    ``token_expert_indices`` tensors in place; this function then returns
    the (weights, indices) pair for convenience.
    """
    ops.topk_softmax(
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
        e_score_correction_bias,
    )

    return topk_weights, topk_indices
|
||||
|
||||
|
||||
def vllm_topk_sigmoid(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
    e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, ...]:
    """Top-k sigmoid routing using vLLM's native kernel.

    Mirrors ``vllm_topk_softmax`` but scores experts with a sigmoid instead
    of a softmax. The kernel fills the pre-allocated output tensors in
    place; the (weights, indices) pair is returned for convenience.
    """
    ops.topk_sigmoid(
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
        e_score_correction_bias,
    )

    return topk_weights, topk_indices
|
||||
|
||||
|
||||
def fused_topk_bias(
|
||||
hidden_states: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
e_score_correction_bias: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
scoring_func: str = "softmax",
|
||||
indices_type: torch.dtype | None = None,
|
||||
):
|
||||
if not rocm_aiter_ops.is_fused_moe_enabled():
|
||||
assert hidden_states.size(0) == gating_output.size(0), (
|
||||
"Number of tokens mismatch"
|
||||
)
|
||||
|
||||
M, _ = hidden_states.size()
|
||||
|
||||
topk_weights = torch.empty(
|
||||
M, topk, dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
topk_ids = torch.empty(
|
||||
M,
|
||||
topk,
|
||||
dtype=torch.int32 if indices_type is None else indices_type,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
token_expert_indices = torch.empty(
|
||||
M, topk, dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
|
||||
if scoring_func == "softmax":
|
||||
topk_weights, topk_ids = vllm_topk_softmax(
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
token_expert_indices,
|
||||
gating_output,
|
||||
renormalize,
|
||||
e_score_correction_bias,
|
||||
)
|
||||
return topk_weights, topk_ids
|
||||
elif scoring_func == "sigmoid":
|
||||
topk_weights, topk_ids = vllm_topk_sigmoid(
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
token_expert_indices,
|
||||
gating_output,
|
||||
renormalize,
|
||||
e_score_correction_bias,
|
||||
)
|
||||
return topk_weights, topk_ids
|
||||
else:
|
||||
raise ValueError(f"Unsupported scoring function: {scoring_func}")
|
||||
|
||||
n_routed_experts = gating_output.shape[-1]
|
||||
scores = gating_output.softmax(dim=-1)
|
||||
if scoring_func == "softmax":
|
||||
scores = gating_output.softmax(dim=-1)
|
||||
elif scoring_func == "sigmoid":
|
||||
scores = gating_output.sigmoid()
|
||||
else:
|
||||
raise ValueError(f"Unsupported scoring function: {scoring_func}")
|
||||
|
||||
scores_for_choice = scores.view(
|
||||
-1, n_routed_experts
|
||||
) + e_score_correction_bias.unsqueeze(0)
|
||||
@@ -43,6 +136,7 @@ class FusedTopKBiasRouter(BaseRouter):
|
||||
global_num_experts: int,
|
||||
eplb_state: EplbLayerState,
|
||||
e_score_correction_bias: torch.Tensor,
|
||||
scoring_func: str,
|
||||
renormalize: bool = True,
|
||||
routed_scaling_factor: float = 1.0,
|
||||
enable_eplb: bool = False,
|
||||
@@ -57,6 +151,7 @@ class FusedTopKBiasRouter(BaseRouter):
|
||||
)
|
||||
self.e_score_correction_bias = e_score_correction_bias
|
||||
self.renormalize = renormalize
|
||||
self.scoring_func = scoring_func
|
||||
self.routed_scaling_factor = routed_scaling_factor
|
||||
|
||||
@property
|
||||
@@ -80,6 +175,7 @@ class FusedTopKBiasRouter(BaseRouter):
|
||||
e_score_correction_bias=self.e_score_correction_bias.data,
|
||||
topk=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
scoring_func=self.scoring_func,
|
||||
)
|
||||
|
||||
if self.routed_scaling_factor != 1.0:
|
||||
|
||||
@@ -16,7 +16,7 @@ def vllm_topk_softmax(
|
||||
topk_indices: torch.Tensor,
|
||||
token_expert_indices: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
renormalize: bool,
|
||||
renormalize: bool = False,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
ops.topk_softmax(
|
||||
topk_weights,
|
||||
@@ -29,7 +29,25 @@ def vllm_topk_softmax(
|
||||
return topk_weights, topk_indices
|
||||
|
||||
|
||||
def dispatch_topk_func(
|
||||
def vllm_topk_sigmoid(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
) -> tuple[torch.Tensor, ...]:
    """Top-k sigmoid routing using vLLM's native kernel (no bias variant).

    Wrapper over ``ops.topk_sigmoid`` for the plain (uncorrected) routing
    path: outputs are written into the pre-allocated tensors in place and
    the (weights, indices) pair is returned.
    """
    ops.topk_sigmoid(
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
    )

    return topk_weights, topk_indices
|
||||
|
||||
|
||||
def dispatch_topk_softmax_func(
|
||||
use_rocm_aiter: bool = False,
|
||||
) -> Callable[..., tuple[torch.Tensor, ...]]:
|
||||
if use_rocm_aiter:
|
||||
@@ -37,12 +55,21 @@ def dispatch_topk_func(
|
||||
return vllm_topk_softmax
|
||||
|
||||
|
||||
def dispatch_topk_sigmoid_func(
    use_rocm_aiter: bool = False,
) -> Callable[..., tuple[torch.Tensor, ...]]:
    """Pick the top-k sigmoid implementation.

    Returns the ROCm AITER-backed routine when ``use_rocm_aiter`` is set,
    otherwise vLLM's native ``vllm_topk_sigmoid``.
    """
    return rocm_aiter_ops.topk_sigmoid if use_rocm_aiter else vllm_topk_sigmoid
|
||||
|
||||
|
||||
def fused_topk(
|
||||
hidden_states: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
indices_type: torch.dtype | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"
|
||||
|
||||
@@ -61,12 +88,26 @@ def fused_topk(
|
||||
M, topk, dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
|
||||
topk_func = dispatch_topk_func(use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled())
|
||||
topk_weights, topk_ids = topk_func(
|
||||
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
|
||||
)
|
||||
if scoring_func == "softmax":
|
||||
topk_func = dispatch_topk_softmax_func(
|
||||
use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()
|
||||
)
|
||||
topk_weights, topk_ids = topk_func(
|
||||
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
|
||||
)
|
||||
|
||||
return topk_weights, topk_ids, token_expert_indices
|
||||
return topk_weights, topk_ids, token_expert_indices
|
||||
elif scoring_func == "sigmoid":
|
||||
topk_func = dispatch_topk_sigmoid_func(
|
||||
use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()
|
||||
)
|
||||
topk_weights, topk_ids = topk_func(
|
||||
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
|
||||
)
|
||||
|
||||
return topk_weights, topk_ids, token_expert_indices
|
||||
else:
|
||||
raise ValueError(f"Unsupported scoring function: {scoring_func}")
|
||||
|
||||
|
||||
class FusedTopKRouter(BaseRouter):
|
||||
@@ -82,7 +123,6 @@ class FusedTopKRouter(BaseRouter):
|
||||
enable_eplb: bool = False,
|
||||
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
|
||||
):
|
||||
assert scoring_func == "softmax", "FusedTopKRouter only supports softmax."
|
||||
super().__init__(
|
||||
top_k=top_k,
|
||||
global_num_experts=global_num_experts,
|
||||
@@ -91,6 +131,7 @@ class FusedTopKRouter(BaseRouter):
|
||||
indices_type_getter=indices_type_getter,
|
||||
)
|
||||
self.renormalize = renormalize
|
||||
self.scoring_func = scoring_func
|
||||
|
||||
@property
|
||||
def routing_method_type(self) -> RoutingMethodType:
|
||||
@@ -113,6 +154,7 @@ class FusedTopKRouter(BaseRouter):
|
||||
topk=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
indices_type=indices_type,
|
||||
scoring_func=self.scoring_func,
|
||||
)
|
||||
|
||||
return topk_weights, topk_ids
|
||||
|
||||
@@ -143,17 +143,13 @@ def create_fused_moe_router(
|
||||
router.capture = capture
|
||||
return router
|
||||
|
||||
if scoring_func != "softmax":
|
||||
raise ValueError(
|
||||
"Only softmax scoring function is supported for non-grouped topk."
|
||||
)
|
||||
|
||||
if e_score_correction_bias is not None:
|
||||
router = FusedTopKBiasRouter(
|
||||
top_k=top_k,
|
||||
global_num_experts=global_num_experts,
|
||||
eplb_state=eplb_state,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
scoring_func=scoring_func,
|
||||
renormalize=renormalize,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
enable_eplb=enable_eplb,
|
||||
|
||||
@@ -100,9 +100,6 @@ class MiniMaxM2MoE(nn.Module):
|
||||
num_experts=config.num_local_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
scoring_func=config.scoring_func,
|
||||
use_grouped_topk=True,
|
||||
num_expert_group=1,
|
||||
topk_group=1,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
|
||||
Reference in New Issue
Block a user