diff --git a/tests/kernels/moe/test_rocm_aiter_topk.py b/tests/kernels/moe/test_rocm_aiter_topk.py index d4724d749..070d00f61 100644 --- a/tests/kernels/moe/test_rocm_aiter_topk.py +++ b/tests/kernels/moe/test_rocm_aiter_topk.py @@ -10,22 +10,28 @@ # and the platform is not ROCm. import importlib.util +import os import pytest import torch +from vllm.platforms import current_platform + +if not current_platform.is_rocm(): + pytest.skip("This test can only run on ROCm.", allow_module_level=True) + +# This environment variable must be set so ops will be registered. +os.environ["VLLM_ROCM_USE_AITER"] = "1" + # this import statement is needed to ensure the ops are registered import vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe # noqa: F401 -from vllm.platforms import current_platform # need to import once to ensure the ops are registered # Check if aiter package is installed aiter_available = importlib.util.find_spec("aiter") is not None -pytestmark = pytest.mark.skipif( - not (current_platform.is_rocm() and aiter_available), - reason="AITER ops are only available on ROCm with aiter package installed", -) +if not aiter_available: + pytest.skip("These tests require AITER to run.", allow_module_level=True) def test_rocm_aiter_biased_grouped_topk_custom_op_registration():