diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index f6e274be9..68abc2b98 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -33,6 +33,7 @@ def graph_allreduce( ): with monkeypatch.context() as m: m.delenv("CUDA_VISIBLE_DEVICES", raising=False) + m.delenv("HIP_VISIBLE_DEVICES", raising=False) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) @@ -92,6 +93,7 @@ def eager_allreduce( ): with monkeypatch.context() as m: m.delenv("CUDA_VISIBLE_DEVICES", raising=False) + m.delenv("HIP_VISIBLE_DEVICES", raising=False) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)