[FEAT] [ROCm]: Add AITER CK 2 Stages MoE support (#17110)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
@@ -8,10 +8,8 @@ from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.activation import (GeluAndMul,
|
||||
ReLUSquaredActivation,
|
||||
SiluAndMul)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
dispatch_fused_experts_func, dispatch_topk_func,
|
||||
torch_vllm_inplace_fused_experts, torch_vllm_outplace_fused_experts,
|
||||
vllm_topk_softmax)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (dispatch_topk_func,
|
||||
vllm_topk_softmax)
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_moe_enabled)
|
||||
from vllm.model_executor.layers.layernorm import (
|
||||
@@ -142,24 +140,6 @@ def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
|
||||
assert topk_func == vllm_topk_softmax
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
|
||||
@pytest.mark.parametrize("inplace", [True, False])
|
||||
def test_fused_experts_dispatch(use_rocm_aiter: str, inplace: bool,
|
||||
monkeypatch):
|
||||
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
|
||||
is_rocm_aiter_moe_enabled.cache_clear()
|
||||
fused_experts_func = dispatch_fused_experts_func(inplace)
|
||||
if current_platform.is_rocm() and int(use_rocm_aiter):
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
rocm_aiter_fused_experts)
|
||||
assert fused_experts_func == rocm_aiter_fused_experts
|
||||
elif inplace:
|
||||
assert fused_experts_func == torch_vllm_inplace_fused_experts
|
||||
else:
|
||||
assert fused_experts_func == torch_vllm_outplace_fused_experts
|
||||
|
||||
|
||||
@pytest.mark.parametrize("add_residual", [True, False])
|
||||
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
|
||||
@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
|
||||
|
||||
Reference in New Issue
Block a user