[RFC][ROCm][AITER] Keep all AITER kernels in _aiter_ops class like _custom_ops and _ipex_ops (#24490)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
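For context, the refactor groups the ROCm AITER kernel entry points on a single `rocm_aiter_ops` class in `vllm._aiter_ops`, mirroring the existing `_custom_ops` and `_ipex_ops` pattern. A minimal, hedged sketch of that shape follows; everything except the `VLLM_ROCM_USE_AITER` environment variable and the `rocm_aiter_ops` class name is illustrative, not vLLM's actual API:

```python
# Hedged sketch: how an _aiter_ops-style class can centralize AITER kernels.
# All names except VLLM_ROCM_USE_AITER are illustrative, not vLLM's real API.
import functools
import os


class rocm_aiter_ops:
    """One home for ROCm AITER wrappers, mirroring _custom_ops/_ipex_ops."""

    @staticmethod
    @functools.cache
    def is_enabled() -> bool:
        # Env check cached once per process; tests that flip the variable
        # must call is_enabled.cache_clear() or reload this module.
        return os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1"

    @staticmethod
    def fused_moe(*args, **kwargs):
        # Each kernel becomes a staticmethod that lazily dispatches into
        # AITER; the real dispatch logic is elided in this sketch.
        raise NotImplementedError("sketch only")
```

The cached enablement check is what forces the test changes in the diff below: once the environment variable is read and memoized, flipping it mid-process has no effect unless the cache is reset.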
@@ -6,6 +6,8 @@ Run `pytest tests/kernels/test_moe.py`.
"""

import functools
import importlib
import sys
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
@@ -20,6 +22,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe  # noqa
from tests.kernels.moe.utils import fused_moe
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.forward_context import set_forward_context
@@ -412,14 +415,12 @@ def test_mixtral_moe(
    huggingface."""

    # Clear the cached AITER-enablement check before every test.
    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
        is_rocm_aiter_moe_enabled,
    )

    # Force-reload vllm._aiter_ops so the rocm_aiter_ops class re-reads the
    # new environment variables (importlib.reload takes the module, not the
    # class, and sys.modules is keyed by the full dotted module name).
    if "vllm._aiter_ops" in sys.modules:
        importlib.reload(sys.modules["vllm._aiter_ops"])

    is_rocm_aiter_moe_enabled.cache_clear()
    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

        if dtype == torch.float32:
            pytest.skip("AITER ROCm test skipped for float32")
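The reload-and-`cache_clear()` dance above exists because the enablement check reads an environment variable once and memoizes the result, so a `monkeypatch.setenv` alone is invisible to it. A self-contained sketch of that test-isolation pattern, with `demo_enabled` as a stand-in for `is_rocm_aiter_moe_enabled` (hypothetical name, not vLLM's):

```python
# Hedged, self-contained sketch of the cache-reset pattern the test relies
# on; demo_enabled is a hypothetical stand-in for the cached vLLM check.
import functools
import os

import pytest


@functools.cache
def demo_enabled() -> bool:
    # Cached once-per-process env check, like is_rocm_aiter_moe_enabled.
    return os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1"


@pytest.mark.parametrize("use_rocm_aiter", [False, True])
def test_env_toggle(use_rocm_aiter, monkeypatch):
    demo_enabled.cache_clear()  # drop any stale cached value first
    monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1" if use_rocm_aiter else "0")
    assert demo_enabled() is use_rocm_aiter
```

Without the `cache_clear()` call, whichever parametrization runs first would pin the cached value for every later test in the same process.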