[ROCm][CI] Fix AITER state leak in shared_fused_moe_routed_transform test (#38137)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -125,8 +125,15 @@ def test_routing_strategy_integration(monkeypatch, device):
|
||||
env_name = "VLLM_MOE_ROUTING_SIMULATION_STRATEGY"
|
||||
monkeypatch.setenv(env_name, strategy)
|
||||
|
||||
# Force reload of environment variable
|
||||
envs.environment_variables[env_name] = lambda s=strategy: s
|
||||
# Temporarily override the envs lookup so the router factory
|
||||
# reads the monkeypatched value instead of the module-load-time
|
||||
# default. Use monkeypatch.setitem so the original lambda is
|
||||
# restored automatically at teardown.
|
||||
monkeypatch.setitem(
|
||||
envs.environment_variables,
|
||||
env_name,
|
||||
lambda s=strategy: s,
|
||||
)
|
||||
|
||||
# Test the select_experts method
|
||||
topk_weights, topk_ids = fused_moe.router.select_experts(
|
||||
|
||||
@@ -137,7 +137,7 @@ def test_routed_input_transform_inside_vs_outside(
|
||||
Method A (inside): SharedFusedMoE with routed_input_transform
|
||||
Method B (outside): Manually transform, then SharedFusedMoE without transform
|
||||
"""
|
||||
if current_platform.is_rocm() and use_rocm_aiter:
|
||||
if current_platform.is_rocm():
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1" if use_rocm_aiter else "0")
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER_MOE", "1" if use_rocm_aiter else "0")
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
|
||||
@@ -1905,6 +1905,17 @@ def destroy_distributed_environment():
|
||||
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
|
||||
# Reset environment variable cache
|
||||
envs.disable_envs_cache()
|
||||
|
||||
# Reset rocm_aiter_ops class variables to match current os.environ.
|
||||
# These are class-level attributes that persist across tests and are
|
||||
# NOT restored by monkeypatch (which only restores os.environ).
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_rocm():
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
|
||||
rocm_aiter_ops.refresh_env_variables()
|
||||
|
||||
# Ensure all objects are not frozen before cleanup
|
||||
gc.unfreeze()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user