[bugfix][distributed] fix shm broadcast when the queue size is full (#5801)

This commit is contained in:
youkaichao
2024-06-25 21:56:02 -07:00
committed by GitHub
parent 3aa7b6cf66
commit 515080ad2f
2 changed files with 76 additions and 46 deletions

View File

@@ -1,7 +1,9 @@
import multiprocessing
import random
import time
from typing import List
import numpy as np
import torch.distributed as dist
from vllm.distributed.device_communicators.shm_broadcast import (
@@ -9,6 +11,14 @@ from vllm.distributed.device_communicators.shm_broadcast import (
from vllm.utils import update_environment_variables
def get_arrays(n: int, seed: int = 0) -> List[np.ndarray]:
np.random.seed(seed)
sizes = np.random.randint(1, 10_000, n)
# on average, each array will have 5k elements
# with int64, each array will have 40kb
return [np.random.randint(1, 100, i) for i in sizes]
def distributed_run(fn, world_size):
number_of_processes = world_size
processes = []
@@ -47,24 +57,31 @@ def worker_fn_wrapper(fn):
def worker_fn():
writer_rank = 2
broadcaster = ShmRingBufferIO.create_from_process_group(
dist.group.WORLD, 1024, 2, writer_rank)
dist.group.WORLD, 1024 * 1024, 2, writer_rank)
if dist.get_rank() == writer_rank:
time.sleep(random.random())
broadcaster.broadcast_object(0)
time.sleep(random.random())
broadcaster.broadcast_object({})
time.sleep(random.random())
broadcaster.broadcast_object([])
seed = random.randint(0, 1000)
dist.broadcast_object_list([seed], writer_rank)
else:
time.sleep(random.random())
a = broadcaster.broadcast_object(None)
time.sleep(random.random())
b = broadcaster.broadcast_object(None)
time.sleep(random.random())
c = broadcaster.broadcast_object(None)
assert a == 0
assert b == {}
assert c == []
recv = [None]
dist.broadcast_object_list(recv, writer_rank)
seed = recv[0] # type: ignore
dist.barrier()
# in case we find a race condition
# print the seed so that we can reproduce the error
print(f"Rank {dist.get_rank()} got seed {seed}")
# test broadcasting with about 400MB of data
N = 10_000
if dist.get_rank() == writer_rank:
arrs = get_arrays(N, seed)
for x in arrs:
broadcaster.broadcast_object(x)
time.sleep(random.random() / 1000)
else:
arrs = get_arrays(N, seed)
for x in arrs:
y = broadcaster.broadcast_object(None)
assert np.array_equal(x, y)
time.sleep(random.random() / 1000)
dist.barrier()