[core][distributed] initialization from StatelessProcessGroup (#10986)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-12-12 01:04:19 -08:00
committed by GitHub
parent 8195824206
commit 62de37a38e
5 changed files with 153 additions and 69 deletions

View File

@@ -7,7 +7,8 @@ import numpy as np
import torch.distributed as dist
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
from vllm.utils import update_environment_variables
from vllm.distributed.utils import StatelessProcessGroup
from vllm.utils import get_ip, get_open_port, update_environment_variables
def get_arrays(n: int, seed: int = 0) -> List[np.ndarray]:
@@ -54,34 +55,61 @@ def worker_fn_wrapper(fn):
@worker_fn_wrapper
def worker_fn():
writer_rank = 2
broadcaster = MessageQueue.create_from_process_group(
dist.group.WORLD, 40 * 1024, 2, writer_rank)
if dist.get_rank() == writer_rank:
seed = random.randint(0, 1000)
dist.broadcast_object_list([seed], writer_rank)
rank = dist.get_rank()
if rank == 0:
port = get_open_port()
ip = get_ip()
dist.broadcast_object_list([ip, port], src=0)
else:
recv = [None]
dist.broadcast_object_list(recv, writer_rank)
seed = recv[0] # type: ignore
dist.barrier()
# in case we find a race condition
# print the seed so that we can reproduce the error
print(f"Rank {dist.get_rank()} got seed {seed}")
# test broadcasting with about 400MB of data
N = 10_000
if dist.get_rank() == writer_rank:
arrs = get_arrays(N, seed)
for x in arrs:
broadcaster.broadcast_object(x)
time.sleep(random.random() / 1000)
else:
arrs = get_arrays(N, seed)
for x in arrs:
y = broadcaster.broadcast_object(None)
assert np.array_equal(x, y)
time.sleep(random.random() / 1000)
dist.barrier()
recv = [None, None]
dist.broadcast_object_list(recv, src=0)
ip, port = recv
stateless_pg = StatelessProcessGroup.create(ip, port, rank,
dist.get_world_size())
for pg in [dist.group.WORLD, stateless_pg]:
writer_rank = 2
broadcaster = MessageQueue.create_from_process_group(
pg, 40 * 1024, 2, writer_rank)
if rank == writer_rank:
seed = random.randint(0, 1000)
dist.broadcast_object_list([seed], writer_rank)
else:
recv = [None]
dist.broadcast_object_list(recv, writer_rank)
seed = recv[0] # type: ignore
if pg == dist.group.WORLD:
dist.barrier()
else:
pg.barrier()
# in case we find a race condition
# print the seed so that we can reproduce the error
print(f"Rank {rank} got seed {seed}")
# test broadcasting with about 400MB of data
N = 10_000
if rank == writer_rank:
arrs = get_arrays(N, seed)
for x in arrs:
broadcaster.broadcast_object(x)
time.sleep(random.random() / 1000)
else:
arrs = get_arrays(N, seed)
for x in arrs:
y = broadcaster.broadcast_object(None)
assert np.array_equal(x, y)
time.sleep(random.random() / 1000)
if pg == dist.group.WORLD:
dist.barrier()
print("torch distributed passed the test!")
else:
pg.barrier()
print("StatelessProcessGroup passed the test!")
def test_shm_broadcast():