[ROCM] Optimize the fused_topk_bias to use aiter instead of fallback torch ops. (#36253)

Signed-off-by: zhutaoyu <zhutaoyu97@gmail.com>
This commit is contained in:
Taoyu Zhu
2026-03-10 00:30:35 +08:00
committed by GitHub
parent 74a9f54cdb
commit 70485a11bd

View File

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from collections.abc import Callable
import torch
@@ -57,6 +58,19 @@ def vllm_topk_sigmoid(
return topk_weights, topk_indices
@functools.lru_cache(maxsize=8)
def _aiter_get_num_expert_group(num_experts: int) -> int:
    """Choose how many expert groups to split ``num_experts`` into.

    The result is the smallest group count that (a) is at least one,
    (b) divides ``num_experts`` exactly, and (c) leaves no more than 32
    experts in any single group. Results are memoized per ``num_experts``.

    Args:
        num_experts: Total number of experts to partition.

    Returns:
        The number of equally sized groups.
    """
    group_size_cap = 32

    # Ceiling division yields the fewest groups that can respect the cap;
    # clamp so there is always at least one group.
    g = (num_experts + group_size_cap - 1) // group_size_cap
    if g < 1:
        g = 1

    # Grow the group count until the experts split evenly across the groups.
    while num_experts % g:
        g += 1

    # Sanity checks on the invariants established above.
    per_group, leftover = divmod(num_experts, g)
    assert leftover == 0, f"{num_experts=} not divisible by {g=}"
    assert per_group <= group_size_cap, (
        f"group size {num_experts // g} exceeds limit {group_size_cap}"
    )
    return g
def fused_topk_bias(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
@@ -108,6 +122,30 @@ def fused_topk_bias(
return topk_weights, topk_ids
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
elif rocm_aiter_ops.is_fused_moe_enabled() and scoring_func == "sigmoid":
M = hidden_states.size(0)
num_experts = gating_output.shape[-1]
num_expert_group = _aiter_get_num_expert_group(num_experts)
if topk >= num_expert_group:
topk_weights = torch.empty(
M, topk, dtype=torch.float32, device=hidden_states.device
)
topk_ids = torch.empty(
M,
topk,
dtype=torch.int32 if indices_type is None else indices_type,
device=hidden_states.device,
)
rocm_aiter_ops.biased_grouped_topk(
gating_output,
e_score_correction_bias.to(gating_output.dtype),
topk_weights,
topk_ids,
num_expert_group=num_expert_group,
topk_group=num_expert_group,
need_renorm=renormalize,
)
return topk_weights, topk_ids
n_routed_experts = gating_output.shape[-1]
if scoring_func == "softmax":