[ROCM] Optimize fused_topk_bias to use aiter instead of fallback torch ops. (#36253)
Signed-off-by: zhutaoyu <zhutaoyu97@gmail.com>
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import functools
 from collections.abc import Callable
 
 import torch
@@ -57,6 +58,19 @@ def vllm_topk_sigmoid(
     return topk_weights, topk_indices
 
 
+@functools.lru_cache(maxsize=8)
+def _aiter_get_num_expert_group(num_experts: int) -> int:
+    _AITER_MAX_EXPERTS_PER_GROUP = 32
+    g = max(1, -(-num_experts // _AITER_MAX_EXPERTS_PER_GROUP))
+    while num_experts % g != 0:
+        g += 1
+    assert num_experts % g == 0, f"{num_experts=} not divisible by {g=}"
+    assert num_experts // g <= _AITER_MAX_EXPERTS_PER_GROUP, (
+        f"group size {num_experts // g} exceeds limit {_AITER_MAX_EXPERTS_PER_GROUP}"
+    )
+    return g
+
+
 def fused_topk_bias(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
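
The heuristic above picks the smallest group count g that both divides num_experts evenly and keeps each group at or below _AITER_MAX_EXPERTS_PER_GROUP (32) experts: it starts from the ceiling division ceil(num_experts / 32), written -(-num_experts // 32), and increments until the division is even. Since g == num_experts always satisfies both conditions, the loop terminates and the two asserts are purely defensive. A standalone sanity check of the same logic (plain Python, not part of the diff; get_num_expert_group is a stand-in name):

    def get_num_expert_group(num_experts: int, max_per_group: int = 32) -> int:
        # Start at ceil(num_experts / max_per_group), then grow until it divides evenly.
        g = max(1, -(-num_experts // max_per_group))
        while num_experts % g != 0:
            g += 1
        return g

    for n in (8, 64, 128, 256, 257):
        g = get_num_expert_group(n)
        print(f"{n} experts -> {g} group(s) of {n // g}")
    # 8 experts -> 1 group(s) of 8
    # 64 experts -> 2 group(s) of 32
    # 128 experts -> 4 group(s) of 32
    # 256 experts -> 8 group(s) of 32
    # 257 experts -> 257 group(s) of 1  (prime counts degenerate to singleton groups)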
@@ -108,6 +122,30 @@ def fused_topk_bias(
             return topk_weights, topk_ids
         else:
             raise ValueError(f"Unsupported scoring function: {scoring_func}")
+    elif rocm_aiter_ops.is_fused_moe_enabled() and scoring_func == "sigmoid":
+        M = hidden_states.size(0)
+        num_experts = gating_output.shape[-1]
+        num_expert_group = _aiter_get_num_expert_group(num_experts)
+        if topk >= num_expert_group:
+            topk_weights = torch.empty(
+                M, topk, dtype=torch.float32, device=hidden_states.device
+            )
+            topk_ids = torch.empty(
+                M,
+                topk,
+                dtype=torch.int32 if indices_type is None else indices_type,
+                device=hidden_states.device,
+            )
+            rocm_aiter_ops.biased_grouped_topk(
+                gating_output,
+                e_score_correction_bias.to(gating_output.dtype),
+                topk_weights,
+                topk_ids,
+                num_expert_group=num_expert_group,
+                topk_group=num_expert_group,
+                need_renorm=renormalize,
+            )
+            return topk_weights, topk_ids
 
     n_routed_experts = gating_output.shape[-1]
     if scoring_func == "softmax":
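
Note that the new branch passes topk_group=num_expert_group, so every group survives group selection and the kernel effectively performs a global biased top-k; the grouping exists only to satisfy the aiter kernel's per-group size limit. For reference, a minimal torch sketch of the routing this path offloads, assuming the usual DeepSeek-V3-style biased-top-k semantics (the correction bias steers which experts are selected, while the returned weights come from the unbiased sigmoid scores); biased_topk_reference is a hypothetical helper, not part of the change:

    import torch

    def biased_topk_reference(gating_output, bias, topk, renormalize):
        # Sigmoid scoring; the correction bias affects *which* experts win...
        scores = gating_output.float().sigmoid()
        _, topk_ids = torch.topk(scores + bias, k=topk, dim=-1)
        # ...but the routing weights are read back from the unbiased scores.
        topk_weights = scores.gather(1, topk_ids)
        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(-1, keepdim=True)
        return topk_weights, topk_ids.to(torch.int32)

The dtypes mirror the buffers pre-allocated in the diff: float32 weights and int32 ids (unless indices_type overrides the latter).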