From 69f8a0ea37a018b2e54897c44f8832de51a2bf5c Mon Sep 17 00:00:00 2001 From: Rabi Mishra Date: Wed, 14 Jan 2026 00:41:54 +0530 Subject: [PATCH] fix(rocm): Use refresh_env_variables() for rocm_aiter_ops in test_moe (#31711) Signed-off-by: rabi --- tests/kernels/moe/test_moe.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index b58c42b7d..2e2581fec 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -6,8 +6,6 @@ Run `pytest tests/kernels/test_moe.py`. """ import functools -import importlib -import sys from collections.abc import Callable from dataclasses import dataclass from typing import Any @@ -592,15 +590,13 @@ def test_mixtral_moe( """Make sure our Mixtral MoE implementation agrees with the one from huggingface.""" - # clear the cache before every test - # Force reload aiter_ops to pick up the new environment variables. - if "rocm_aiter_ops" in sys.modules: - importlib.reload(rocm_aiter_ops) + # Explicitly set AITER env var based on test parameter to ensure + # consistent behavior regardless of external environment + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1" if use_rocm_aiter else "0") + rocm_aiter_ops.refresh_env_variables() - if use_rocm_aiter: - monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - if dtype == torch.float32: - pytest.skip("AITER ROCm test skip for float32") + if use_rocm_aiter and dtype == torch.float32: + pytest.skip("AITER ROCm test skip for float32") monkeypatch.setenv("RANK", "0") monkeypatch.setenv("LOCAL_RANK", "0")