[FEAT] [ROCm] [V1]: Add AITER biased group topk for DeepSeekV3 (#17955)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
vllmellm
2025-05-14 13:03:47 +08:00
committed by GitHub
parent 12e6c0b41c
commit 2d912fb66f
3 changed files with 201 additions and 2 deletions

View File

@@ -17,6 +17,8 @@ from vllm.distributed import (get_dp_group, get_ep_group,
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
@@ -28,6 +30,11 @@ if current_platform.is_cuda_alike():
from .fused_moe import fused_experts
else:
fused_experts = None # type: ignore
if is_rocm_aiter_moe_enabled():
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_biased_group_topk as grouped_topk)
else:
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
if current_platform.is_tpu():
# the iterative moe implementation is used until the moe_pallas is fixed
from .moe_torch_iterative import fused_moe as fused_moe_pallas
@@ -802,8 +809,7 @@ class FusedMoE(torch.nn.Module):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None):
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, grouped_topk)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
# DeekSeekv2 uses grouped_top_k
if use_grouped_topk: