[ROCm] Add aiter tkw1 kernel for Llama4 fp8 (#16727)

Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
Author: kliuae
Date: 2025-04-22 11:42:34 +08:00
Committed by: GitHub
Parent: 0e4254492f
Commit: 5b794cae8d
6 changed files with 134 additions and 48 deletions

@@ -23,9 +23,7 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
-from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled,
-                                   rocm_aiter_fused_experts,
-                                   rocm_aiter_topk_softmax)
+from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
 
 logger = init_logger(__name__)
@@ -846,6 +844,7 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
 
 def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
     if is_rocm_aiter_moe_enabled():
+        from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax
         return rocm_aiter_topk_softmax
     return vllm_topk_softmax
@@ -1102,6 +1101,7 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
 
 def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
     if is_rocm_aiter_moe_enabled():
+        from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
         return rocm_aiter_fused_experts
     if inplace:
         return torch_vllm_inplace_fused_experts
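
Note on the pattern in the hunks above: the module-level import of the ROCm aiter
kernels is replaced by imports deferred into the dispatch functions, so
`rocm_aiter_fused_moe` (and the aiter extension behind it) is only loaded when
`is_rocm_aiter_moe_enabled()` reports the backend as active. Below is a minimal,
self-contained sketch of this deferred-import dispatch; `rocm_aiter_backend` and
the stub bodies are hypothetical stand-ins, not the actual vLLM module layout.

from typing import Any, Callable


def is_rocm_aiter_moe_enabled() -> bool:
    # In vLLM this consults the platform and an environment flag;
    # hardcoded here so the sketch runs anywhere.
    return False


def vllm_topk_softmax(*args: Any, **kwargs: Any) -> tuple:
    # Always-importable fallback implementation (stub).
    return ()


def dispatch_topk_func() -> Callable[..., tuple]:
    if is_rocm_aiter_moe_enabled():
        # The ROCm-only module is imported inside the branch, so it is
        # never touched on platforms where the backend is disabled.
        # Module name is hypothetical.
        from rocm_aiter_backend import rocm_aiter_topk_softmax
        return rocm_aiter_topk_softmax
    return vllm_topk_softmax


topk_func = dispatch_topk_func()  # resolves to the fallback here

Deferring the import this way keeps the fused-MoE module importable on installs
where the aiter extension is absent, while ROCm users pay the import cost only
once, at dispatch time.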