fix(rocm): Use refresh_env_variables() for rocm_aiter_ops in test_moe (#31711)
Signed-off-by: rabi <ramishra@redhat.com>
This commit is contained in:
@@ -6,8 +6,6 @@ Run `pytest tests/kernels/test_moe.py`.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import importlib
|
||||
import sys
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
@@ -592,15 +590,13 @@ def test_mixtral_moe(
|
||||
"""Make sure our Mixtral MoE implementation agrees with the one from
|
||||
huggingface."""
|
||||
|
||||
# clear the cache before every test
|
||||
# Force reload aiter_ops to pick up the new environment variables.
|
||||
if "rocm_aiter_ops" in sys.modules:
|
||||
importlib.reload(rocm_aiter_ops)
|
||||
# Explicitly set AITER env var based on test parameter to ensure
|
||||
# consistent behavior regardless of external environment
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1" if use_rocm_aiter else "0")
|
||||
rocm_aiter_ops.refresh_env_variables()
|
||||
|
||||
if use_rocm_aiter:
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
if dtype == torch.float32:
|
||||
pytest.skip("AITER ROCm test skip for float32")
|
||||
if use_rocm_aiter and dtype == torch.float32:
|
||||
pytest.skip("AITER ROCm test skip for float32")
|
||||
|
||||
monkeypatch.setenv("RANK", "0")
|
||||
monkeypatch.setenv("LOCAL_RANK", "0")
|
||||
|
||||
Reference in New Issue
Block a user