[RFC][ROCm][AITER] Keep all AITER kernels in _aiter_ops class like _custom_ops and _ipex_ops (#24490)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
vllmellm
2025-11-10 17:20:53 +01:00
committed by GitHub
parent 40e2eeeb92
commit f080a83511
25 changed files with 1193 additions and 924 deletions

View File

@@ -6,6 +6,8 @@ Run `pytest tests/kernels/test_moe.py`.
"""
import functools
import importlib
import sys
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
@@ -20,6 +22,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.moe.utils import fused_moe
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.forward_context import set_forward_context
@@ -412,14 +415,12 @@ def test_mixtral_moe(
huggingface."""
# clear the cache before every test
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled,
)
# Force reload aiter_ops to pick up the new environment variables.
if "rocm_aiter_ops" in sys.modules:
importlib.reload(rocm_aiter_ops)
is_rocm_aiter_moe_enabled.cache_clear()
if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
if dtype == torch.float32:
pytest.skip("AITER ROCm test skip for float32")