Signed-off-by: Joe Runde <Joseph.Runde@ibm.com> Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com> Signed-off-by: Nick Hill <nickhill123@gmail.com> Co-authored-by: Travis Johnson <tsjohnso@us.ibm.com> Co-authored-by: Nick Hill <nickhill123@gmail.com>
395 lines
12 KiB
Python
395 lines
12 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import functools
import random
import threading
import time
from unittest import mock

import multiprocess as mp
import numpy as np
import pytest
import torch.distributed as dist

from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
from vllm.distributed.utils import StatelessProcessGroup
from vllm.utils.network_utils import get_open_port
from vllm.utils.system_utils import update_environment_variables
|
|
|
|
|
|
def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]:
    """Build ``n`` reproducible random integer arrays of random length.

    A private ``RandomState`` is used instead of reseeding NumPy's global
    generator, so calling this helper does not disturb any other code that
    draws from ``np.random``. The values are the same ones the global
    generator would produce right after ``np.random.seed(seed)``.

    Args:
        n: Number of arrays to generate.
        seed: Seed for the private generator; equal seeds give equal arrays.

    Returns:
        A list of ``n`` 1-D arrays with values in ``[1, 100)`` and lengths in
        ``[1, 10_000)``.
    """
    rng = np.random.RandomState(seed)
    sizes = rng.randint(1, 10_000, n)
    # on average, each array will have 5k elements
    # with 8-byte integers (int64), each array will be about 40kb
    return [rng.randint(1, 100, size) for size in sizes]
|
|
|
|
|
|
def distributed_run(fn, world_size, timeout=60):
    """Run a function in multiple processes with proper error handling.

    Each process receives, as its only argument, a dict of rendezvous
    environment variables (``RANK``, ``LOCAL_RANK``, ``WORLD_SIZE``,
    ``LOCAL_WORLD_SIZE``, ``MASTER_ADDR``, ``MASTER_PORT``).

    Args:
        fn: Function to run in each process
        world_size: Number of processes to spawn
        timeout: Maximum time in seconds to wait for processes (default: 60)

    Raises:
        AssertionError: If any process exits with a non-zero exit code, or if
            any process is still running once the timeout has elapsed.
    """
    processes = []
    for rank in range(world_size):
        env = {
            "RANK": str(rank),
            "LOCAL_RANK": str(rank),
            "WORLD_SIZE": str(world_size),
            "LOCAL_WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12345",
        }
        p = mp.Process(target=fn, args=(env,))
        processes.append(p)
        p.start()

    # Poll instead of joining so that one crashed rank fails the run right
    # away rather than leaving its peers blocked until the timeout.
    start_time = time.time()
    while time.time() - start_time < timeout:
        any_failed = any(
            not p.is_alive() and p.exitcode != 0 for p in processes
        )
        if any_failed or not any(p.is_alive() for p in processes):
            break
        time.sleep(0.1)  # Check every 100ms

    # Whatever is still running now either outlived the timeout or is a peer
    # of a rank that already failed; stop it so nothing leaks past the test.
    killed_ranks = [rank for rank, p in enumerate(processes) if p.is_alive()]
    for rank in killed_ranks:
        processes[rank].kill()
    for p in processes:
        p.join()

    # Collect exit codes only after joining, so a failure that landed after
    # the last poll is still reported.
    killed = set(killed_ranks)
    failed_processes = [
        (rank, p.exitcode)
        for rank, p in enumerate(processes)
        if rank not in killed and p.exitcode != 0
    ]

    # Report failures; a timeout counts as a failure too.
    if failed_processes or killed_ranks:
        lines = ["Distributed test failed:\n"]
        lines.extend(
            f" Rank {rank}: Exit code {status}\n"
            for rank, status in failed_processes
        )
        if failed_processes:
            reason = "Killed after another rank failed"
        else:
            reason = f"Timed out after {timeout} seconds"
        lines.extend(f" Rank {rank}: {reason}\n" for rank in killed_ranks)
        raise AssertionError("".join(lines))
|
|
|
|
|
|
def worker_fn_wrapper(fn):
    """Decorate a zero-argument worker so it can be launched as a process.

    `mp.Process` cannot accept environment variables directly, so the
    environment is passed as an argument instead. The returned function
    applies it, initializes the default ``gloo`` process group, and then
    runs ``fn``.

    Args:
        fn: Worker body taking no arguments.

    Returns:
        A function taking a single ``env`` dict. It keeps ``fn``'s name and
        docstring so the worker stays identifiable in tracebacks.
    """

    @functools.wraps(fn)
    def wrapped_fn(env):
        update_environment_variables(env)
        dist.init_process_group(backend="gloo")
        fn()

    return wrapped_fn
|
|
|
|
|
|
@worker_fn_wrapper
def worker_fn():
    """Broadcast many random arrays through a MessageQueue and verify them.

    The same scenario runs twice: once over the default torch.distributed
    group and once over a StatelessProcessGroup. Rank 2 is the writer; every
    other rank regenerates the same arrays from a shared seed and compares
    them with what it receives.
    """
    rank = dist.get_rank()

    # Rank 0 picks a free port and shares the rendezvous address.
    if rank == 0:
        address = ["127.0.0.1", get_open_port()]
    else:
        address = [None, None]
    dist.broadcast_object_list(address, src=0)
    ip, port = address

    stateless_pg = StatelessProcessGroup.create(ip, port, rank, dist.get_world_size())

    writer_rank = 2
    is_writer = rank == writer_rank
    # test broadcasting with about 400MB of data
    n_arrays = 10_000

    for pg in (dist.group.WORLD, stateless_pg):
        uses_world_group = pg == dist.group.WORLD
        # NOTE(review): 40 * 1024 and 2 are presumably the chunk size in bytes
        # and the chunk count -- confirm against create_from_process_group.
        broadcaster = MessageQueue.create_from_process_group(
            pg, 40 * 1024, 2, writer_rank
        )

        # The writer draws the seed; everyone else receives it.
        seed_box = [random.randint(0, 1000)] if is_writer else [None]
        dist.broadcast_object_list(seed_box, writer_rank)
        seed = seed_box[0]

        if uses_world_group:
            dist.barrier()
        else:
            pg.barrier()

        # in case we find a race condition
        # print the seed so that we can reproduce the error
        print(f"Rank {rank} got seed {seed}")

        for expected in get_arrays(n_arrays, seed):
            if is_writer:
                broadcaster.broadcast_object(expected)
            else:
                received = broadcaster.broadcast_object(None)
                assert np.array_equal(expected, received)
            # small random pause after each message on every rank
            time.sleep(random.random() / 1000)

        if uses_world_group:
            dist.barrier()
            print(f"torch distributed passed the test! Rank {rank}")
        else:
            pg.barrier()
            print(f"StatelessProcessGroup passed the test! Rank {rank}")
|
|
|
|
|
|
def test_shm_broadcast():
    """Run the broadcast worker in four processes."""
    distributed_run(worker_fn, world_size=4)
|
|
|
|
|
|
@worker_fn_wrapper
def worker_fn_test_shutdown_busy():
    """Check that shutdown() cancels a reader's dequeue() in busy mode.

    Readers first see a TimeoutError on an empty queue. A helper thread then
    shuts the queue down, and the next dequeue() must raise a RuntimeError
    matching "cancelled". The writer (rank 2) only takes part in the final
    barrier.
    """
    rank = dist.get_rank()
    writer_rank = 2
    message_queue = MessageQueue.create_from_process_group(
        dist.group.WORLD, 40 * 1024, 2, writer_rank
    )

    if not message_queue._is_writer:
        # Put into busy mode
        message_queue._spin_condition.busy_loop_s = 9999

        shutdown_event = threading.Event()

        def shutdown_thread(mq, shutdown_event):
            shutdown_event.wait()
            mq.shutdown()

        shutdown_worker = threading.Thread(
            target=shutdown_thread, args=(message_queue, shutdown_event)
        )
        shutdown_worker.start()

        try:
            with pytest.raises(TimeoutError):
                message_queue.dequeue(timeout=0.01)

            shutdown_event.set()

            with pytest.raises(RuntimeError, match="cancelled"):
                message_queue.dequeue(timeout=1)
        finally:
            # If an expectation above fails before the event is set, the
            # non-daemon helper thread would wait forever and keep this
            # process from exiting. Always release it and wait for it.
            shutdown_event.set()
            shutdown_worker.join()

        assert message_queue.shutting_down

    print(f"torch distributed passed the test! Rank {rank}")
    dist.barrier()
|
|
|
|
|
|
def test_message_queue_shutdown_busy(caplog_vllm):
    """Run the busy-mode shutdown worker in four processes, then print the captured log text."""
    distributed_run(worker_fn_test_shutdown_busy, world_size=4)
    print(caplog_vllm.text)
|
|
|
|
|
|
@worker_fn_wrapper
def worker_fn_test_shutdown_idle():
    """Check that shutdown() cancels a reader's dequeue() in idle mode.

    Readers first see a TimeoutError on an empty queue. A helper thread then
    shuts the queue down, and the next dequeue() must raise a RuntimeError
    matching "cancelled". The writer (rank 2) only takes part in the final
    barrier.
    """
    rank = dist.get_rank()
    writer_rank = 2
    message_queue = MessageQueue.create_from_process_group(
        dist.group.WORLD, 40 * 1024, 2, writer_rank
    )

    if not message_queue._is_writer:
        # Put into idle mode
        message_queue._spin_condition.last_read = 0

        shutdown_event = threading.Event()

        def shutdown_thread(mq, shutdown_event):
            shutdown_event.wait()
            mq.shutdown()

        shutdown_worker = threading.Thread(
            target=shutdown_thread, args=(message_queue, shutdown_event)
        )
        shutdown_worker.start()

        try:
            with pytest.raises(TimeoutError):
                message_queue.dequeue(timeout=0.01)

            shutdown_event.set()

            with pytest.raises(RuntimeError, match="cancelled"):
                message_queue.dequeue(timeout=1)
        finally:
            # If an expectation above fails before the event is set, the
            # non-daemon helper thread would wait forever and keep this
            # process from exiting. Always release it and wait for it.
            shutdown_event.set()
            shutdown_worker.join()

        assert message_queue.shutting_down

    print(f"torch distributed passed the test! Rank {rank}")
    dist.barrier()
|
|
|
|
|
|
def test_message_queue_shutdown_idle():
    """Run the idle-mode shutdown worker in four processes."""
    distributed_run(worker_fn_test_shutdown_idle, world_size=4)
|
|
|
|
|
|
@worker_fn_wrapper
def worker_fn_test_idle_to_busy():
    """Check how often readers call the spin condition's wait() when going idle -> busy.

    ``wait`` is wrapped with a mock so its call count can be inspected. The
    writer (rank 2) sends two messages, each after a barrier and a short
    sleep; readers assert on the call count after each dequeue.
    """
    rank = dist.get_rank()
    writer_rank = 2
    message_queue = MessageQueue.create_from_process_group(
        dist.group.WORLD, 40 * 1024, 2, writer_rank
    )

    message1 = "hello world"
    # Every rank builds its own copy of the expected payload, so it must come
    # from an explicitly seeded generator. The unseeded global RNG gives the
    # same values in every process only when its state is inherited via fork.
    message2 = np.random.RandomState(0).randint(1, 100, 100)
    with mock.patch.object(
        message_queue._spin_condition, "wait", wraps=message_queue._spin_condition.wait
    ) as wrapped_wait:
        if not message_queue._is_writer:
            # Put into idle mode
            message_queue._spin_condition.last_read = 0

            # no messages, so expect a TimeoutError
            with pytest.raises(TimeoutError):
                message_queue.dequeue(timeout=0.01)
            # wait should only be called once while idle
            assert wrapped_wait.call_count == 1

            # sync with the writer and wait for message1
            dist.barrier()
            recv_message = message_queue.dequeue(timeout=5)
            assert recv_message == message1
            # second call to wait, with a message read, this puts in a busy spin
            assert wrapped_wait.call_count == 2

            # sync with the writer and wait for message2
            dist.barrier()
            recv_message = message_queue.dequeue(timeout=1)
            assert np.array_equal(recv_message, message2)
            # in busy mode, we expect wait to have been called multiple times
            assert wrapped_wait.call_count > 3
        else:
            # writer writes two messages in sync with the reader
            dist.barrier()
            # sleep delays the send to ensure reader enters the read loop
            time.sleep(0.1)
            message_queue.enqueue(message1)

            dist.barrier()
            time.sleep(0.1)
            message_queue.enqueue(message2)

    message_queue.shutdown()
    assert message_queue.shutting_down
    print(f"torch distributed passed the test! Rank {rank}")
|
|
|
|
|
|
def test_message_queue_idle_wake():
    """Run the idle-to-busy worker in four processes."""
    distributed_run(worker_fn_test_idle_to_busy, world_size=4)
|
|
|
|
|
|
@worker_fn_wrapper
def worker_fn_test_busy_to_idle():
    """Check how often readers call the spin condition's wait() when going busy -> idle.

    ``wait`` is wrapped with a mock so its call count can be inspected. The
    writer (rank 2) sends two messages, each after a barrier and a short
    sleep. Readers start in busy mode, are switched back by zeroing
    ``busy_loop_s``, and then expect exactly one more wait() call for the
    second message.
    """
    rank = dist.get_rank()
    writer_rank = 2
    message_queue = MessageQueue.create_from_process_group(
        dist.group.WORLD, 40 * 1024, 2, writer_rank
    )
    spin = message_queue._spin_condition

    first_payload = 12345
    second_payload = list(range(3))
    with mock.patch.object(spin, "wait", wraps=spin.wait) as wait_spy:
        if message_queue._is_writer:
            # The writer sends each message in step with the readers. The
            # sleep delays the send so readers are already in the read loop.
            for payload in (first_payload, second_payload):
                dist.barrier()
                time.sleep(0.1)
                message_queue.enqueue(payload)
        else:
            # Force busy mode.
            spin.busy_loop_s = 9999

            # Line up with the writer, then receive the first message.
            dist.barrier()
            received = message_queue.dequeue(timeout=1)
            assert received == first_payload
            # Busy mode is expected to call wait() repeatedly.
            assert wait_spy.call_count > 1

            # Pretend the busy window has expired.
            spin.busy_loop_s = 0
            # A timed-out read makes sure idle mode is entered before the
            # baseline call count is recorded.
            with pytest.raises(TimeoutError):
                message_queue.dequeue(timeout=0.01)
            calls_before_wake = wait_spy.call_count

            # Line up with the writer again, then receive the second message.
            dist.barrier()
            received = message_queue.dequeue(timeout=1)
            assert received == second_payload

            # Waking from idle should cost exactly one wait() call.
            assert wait_spy.call_count == calls_before_wake + 1

    message_queue.shutdown()
    assert message_queue.shutting_down
    print(f"torch distributed passed the test! Rank {rank}")
|
|
|
|
|
|
def test_message_queue_busy_to_idle():
    """Run the busy-to-idle worker in four processes."""
    distributed_run(worker_fn_test_busy_to_idle, world_size=4)
|
|
|
|
|
|
def test_warning_logs(caplog_vllm):
    """
    Test that warning logs are emitted at VLLM_RINGBUFFER_WARNING_INTERVAL intervals
    when indefinite=False, and are not emitted when indefinite=True.

    Both queues are shut down in a ``finally`` block, so a failing assertion
    does not skip the cleanup.
    """
    warning_text = "No available shared memory broadcast block found in 0 seconds"

    # Patch the warning log interval to every 1 ms during reads
    with mock.patch(
        "vllm.distributed.device_communicators.shm_broadcast.VLLM_RINGBUFFER_WARNING_INTERVAL",
        new=0.001,  # 1 ms
    ):
        writer = MessageQueue(
            n_reader=1,
            n_local_reader=1,
            max_chunk_bytes=1024 * 1024,  # 1MB chunks
            max_chunks=10,
        )
        reader = None
        try:
            reader = MessageQueue.create_from_handle(writer.export_handle(), rank=0)
            writer.wait_until_ready()
            reader.wait_until_ready()

            # We should have at least one warning log here
            # "0 seconds" expected due to rounding of 1ms test interval
            with pytest.raises(TimeoutError):
                reader.dequeue(timeout=0.01, indefinite=False)
            assert any(
                warning_text in record.message for record in caplog_vllm.records
            )
            caplog_vllm.clear()

            # We should have no warnings this time
            with pytest.raises(TimeoutError):
                reader.dequeue(timeout=0.01, indefinite=True)
            assert all(
                warning_text not in record.message for record in caplog_vllm.records
            )
        finally:
            # Clean up when done, even if an assertion above failed
            writer.shutdown()
            if reader is not None:
                reader.shutdown()
|