[CI] Fix race condition with StatelessProcessGroup.barrier (#18506)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
@@ -9,7 +9,7 @@ import torch.distributed as dist
|
||||
|
||||
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
-from vllm.utils import get_ip, get_open_port, update_environment_variables
|
||||
+from vllm.utils import get_open_port, update_environment_variables
|
||||
|
||||
|
||||
def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]:
|
||||
@@ -60,12 +60,12 @@ def worker_fn():
|
||||
rank = dist.get_rank()
|
||||
if rank == 0:
|
||||
port = get_open_port()
|
||||
-ip = get_ip()
|
||||
+ip = '127.0.0.1'
|
||||
dist.broadcast_object_list([ip, port], src=0)
|
||||
else:
|
||||
recv = [None, None]
|
||||
dist.broadcast_object_list(recv, src=0)
|
||||
-ip, port = recv
|
||||
+ip, port = recv # type: ignore
|
||||
|
||||
stateless_pg = StatelessProcessGroup.create(ip, port, rank,
|
||||
dist.get_world_size())
|
||||
@@ -107,10 +107,10 @@ def worker_fn():
|
||||
|
||||
if pg == dist.group.WORLD:
|
||||
dist.barrier()
|
||||
-print("torch distributed passed the test!")
|
||||
+print(f"torch distributed passed the test! Rank {rank}")
|
||||
else:
|
||||
pg.barrier()
|
||||
-print("StatelessProcessGroup passed the test!")
|
||||
+print(f"StatelessProcessGroup passed the test! Rank {rank}")
|
||||
|
||||
|
||||
def test_shm_broadcast():
|
||||
|
||||
Reference in New Issue
Block a user