[RFC][ROCm][AITER] Keep all AITER kernels in _aiter_ops class like _custom_ops and _ipex_ops (#24490)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -14,6 +14,7 @@ import torch.nn.functional as F
 import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
+from vllm._aiter_ops import rocm_aiter_ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
     vllm_is_batch_invariant,
@@ -55,8 +56,6 @@ from vllm.triton_utils import tl, triton
 from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
 from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
 
-from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
-
 logger = init_logger(__name__)
 
 
||||
@@ -1089,11 +1088,11 @@ def vllm_topk_softmax(
     return topk_weights, topk_indices
 
 
-def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
-    if is_rocm_aiter_moe_enabled():
-        from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax
-
-        return rocm_aiter_topk_softmax
+def dispatch_topk_func(
+    use_rocm_aiter: bool = False,
+) -> Callable[..., tuple[torch.Tensor, ...]]:
+    if use_rocm_aiter:
+        return rocm_aiter_ops.topk_softmax
     return vllm_topk_softmax
 
 
||||
@@ -1121,7 +1120,7 @@ def fused_topk(
         M, topk, dtype=torch.int32, device=hidden_states.device
     )
 
-    topk_func = dispatch_topk_func()
+    topk_func = dispatch_topk_func(use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled())
     topk_weights, topk_ids = topk_func(
         topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
     )
||||
Reference in New Issue
Block a user