diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index b58c42b7d..2e2581fec 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -6,8 +6,6 @@ Run `pytest tests/kernels/test_moe.py`. """ import functools -import importlib -import sys from collections.abc import Callable from dataclasses import dataclass from typing import Any @@ -592,15 +590,13 @@ def test_mixtral_moe( """Make sure our Mixtral MoE implementation agrees with the one from huggingface.""" - # clear the cache before every test - # Force reload aiter_ops to pick up the new environment variables. - if "rocm_aiter_ops" in sys.modules: - importlib.reload(rocm_aiter_ops) + # Explicitly set AITER env var based on test parameter to ensure + # consistent behavior regardless of external environment + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1" if use_rocm_aiter else "0") + rocm_aiter_ops.refresh_env_variables() - if use_rocm_aiter: - monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - if dtype == torch.float32: - pytest.skip("AITER ROCm test skip for float32") + if use_rocm_aiter and dtype == torch.float32: + pytest.skip("AITER ROCm test skip for float32") monkeypatch.setenv("RANK", "0") monkeypatch.setenv("LOCAL_RANK", "0")