[Core] Remove busy loop from idle buffer readers (#28053)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com> Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com> Signed-off-by: Nick Hill <nickhill123@gmail.com> Co-authored-by: Travis Johnson <tsjohnso@us.ibm.com> Co-authored-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -1,11 +1,14 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import multiprocessing
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
from unittest import mock
|
||||
|
||||
import multiprocess as mp
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
|
||||
@@ -22,7 +25,14 @@ def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]:
|
||||
return [np.random.randint(1, 100, i) for i in sizes]
|
||||
|
||||
|
||||
def distributed_run(fn, world_size):
|
||||
def distributed_run(fn, world_size, timeout=60):
|
||||
"""Run a function in multiple processes with proper error handling.
|
||||
|
||||
Args:
|
||||
fn: Function to run in each process
|
||||
world_size: Number of processes to spawn
|
||||
timeout: Maximum time in seconds to wait for processes (default: 60)
|
||||
"""
|
||||
number_of_processes = world_size
|
||||
processes = []
|
||||
for i in range(number_of_processes):
|
||||
@@ -33,19 +43,45 @@ def distributed_run(fn, world_size):
|
||||
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
|
||||
env["MASTER_ADDR"] = "localhost"
|
||||
env["MASTER_PORT"] = "12345"
|
||||
p = multiprocessing.Process(target=fn, args=(env,))
|
||||
p = mp.Process(target=fn, args=(env,))
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
for p in processes:
|
||||
p.join()
|
||||
# Monitor processes and fail fast if any process fails
|
||||
start_time = time.time()
|
||||
failed_processes = []
|
||||
|
||||
for p in processes:
|
||||
assert p.exitcode == 0
|
||||
# Wait for all processes, checking for failures
|
||||
while time.time() - start_time < timeout:
|
||||
all_done = True
|
||||
for i, p in enumerate(processes):
|
||||
if p.is_alive():
|
||||
all_done = False
|
||||
elif p.exitcode != 0:
|
||||
# Process failed
|
||||
failed_processes.append((i, p.exitcode))
|
||||
break
|
||||
|
||||
if failed_processes or all_done:
|
||||
break
|
||||
time.sleep(0.1) # Check every 100ms
|
||||
|
||||
# Check for timeout if no failures detected yet
|
||||
for i, p in enumerate(processes):
|
||||
if p.is_alive():
|
||||
p.kill()
|
||||
p.join()
|
||||
|
||||
# Report failures
|
||||
if failed_processes:
|
||||
error_msg = "Distributed test failed:\n"
|
||||
for rank, status in failed_processes:
|
||||
error_msg += f" Rank {rank}: Exit code {status}\n"
|
||||
raise AssertionError(error_msg)
|
||||
|
||||
|
||||
def worker_fn_wrapper(fn):
|
||||
# `multiprocessing.Process` cannot accept environment variables directly
|
||||
# `mp.Process` cannot accept environment variables directly
|
||||
# so we need to pass the environment variables as arguments
|
||||
# and update the environment variables in the function
|
||||
def wrapped_fn(env):
|
||||
@@ -115,3 +151,244 @@ def worker_fn():
|
||||
|
||||
def test_shm_broadcast():
|
||||
distributed_run(worker_fn, 4)
|
||||
|
||||
|
||||
@worker_fn_wrapper
|
||||
def worker_fn_test_shutdown_busy():
|
||||
rank = dist.get_rank()
|
||||
writer_rank = 2
|
||||
message_queue = MessageQueue.create_from_process_group(
|
||||
dist.group.WORLD, 40 * 1024, 2, writer_rank
|
||||
)
|
||||
|
||||
if not message_queue._is_writer:
|
||||
# Put into busy mode
|
||||
message_queue._spin_condition.busy_loop_s = 9999
|
||||
|
||||
shutdown_event = threading.Event()
|
||||
|
||||
def shutdown_thread(mq, shutdown_event):
|
||||
shutdown_event.wait()
|
||||
mq.shutdown()
|
||||
|
||||
threading.Thread(
|
||||
target=shutdown_thread, args=(message_queue, shutdown_event)
|
||||
).start()
|
||||
|
||||
with pytest.raises(TimeoutError):
|
||||
message_queue.dequeue(timeout=0.01)
|
||||
|
||||
shutdown_event.set()
|
||||
|
||||
with pytest.raises(RuntimeError, match="cancelled"):
|
||||
message_queue.dequeue(timeout=1)
|
||||
|
||||
assert message_queue.shutting_down
|
||||
|
||||
print(f"torch distributed passed the test! Rank {rank}")
|
||||
dist.barrier()
|
||||
|
||||
|
||||
def test_message_queue_shutdown_busy(caplog_vllm):
|
||||
distributed_run(worker_fn_test_shutdown_busy, 4)
|
||||
print(caplog_vllm.text)
|
||||
|
||||
|
||||
@worker_fn_wrapper
|
||||
def worker_fn_test_shutdown_idle():
|
||||
rank = dist.get_rank()
|
||||
writer_rank = 2
|
||||
message_queue = MessageQueue.create_from_process_group(
|
||||
dist.group.WORLD, 40 * 1024, 2, writer_rank
|
||||
)
|
||||
|
||||
if not message_queue._is_writer:
|
||||
# Put into idle mode
|
||||
message_queue._spin_condition.last_read = 0
|
||||
|
||||
shutdown_event = threading.Event()
|
||||
|
||||
def shutdown_thread(mq, shutdown_event):
|
||||
shutdown_event.wait()
|
||||
mq.shutdown()
|
||||
|
||||
threading.Thread(
|
||||
target=shutdown_thread, args=(message_queue, shutdown_event)
|
||||
).start()
|
||||
|
||||
with pytest.raises(TimeoutError):
|
||||
message_queue.dequeue(timeout=0.01)
|
||||
|
||||
shutdown_event.set()
|
||||
|
||||
with pytest.raises(RuntimeError, match="cancelled"):
|
||||
message_queue.dequeue(timeout=1)
|
||||
|
||||
assert message_queue.shutting_down
|
||||
|
||||
print(f"torch distributed passed the test! Rank {rank}")
|
||||
dist.barrier()
|
||||
|
||||
|
||||
def test_message_queue_shutdown_idle():
|
||||
distributed_run(worker_fn_test_shutdown_idle, 4)
|
||||
|
||||
|
||||
@worker_fn_wrapper
|
||||
def worker_fn_test_idle_to_busy():
|
||||
rank = dist.get_rank()
|
||||
writer_rank = 2
|
||||
message_queue = MessageQueue.create_from_process_group(
|
||||
dist.group.WORLD, 40 * 1024, 2, writer_rank
|
||||
)
|
||||
|
||||
message1 = "hello world"
|
||||
message2 = np.random.randint(1, 100, 100)
|
||||
with mock.patch.object(
|
||||
message_queue._spin_condition, "wait", wraps=message_queue._spin_condition.wait
|
||||
) as wrapped_wait:
|
||||
if not message_queue._is_writer:
|
||||
# Put into idle mode
|
||||
message_queue._spin_condition.last_read = 0
|
||||
|
||||
# no messages, so expect a TimeoutError
|
||||
with pytest.raises(TimeoutError):
|
||||
message_queue.dequeue(timeout=0.01)
|
||||
# wait should only be called once while idle
|
||||
assert wrapped_wait.call_count == 1
|
||||
|
||||
# sync with the writer and wait for message1
|
||||
dist.barrier()
|
||||
recv_message = message_queue.dequeue(timeout=5)
|
||||
assert recv_message == message1
|
||||
# second call to wait, with a message read, this puts in a busy spin
|
||||
assert wrapped_wait.call_count == 2
|
||||
|
||||
# sync with the writer and wait for message2
|
||||
dist.barrier()
|
||||
recv_message = message_queue.dequeue(timeout=1)
|
||||
assert np.array_equal(recv_message, message2)
|
||||
# in busy mode, we expect wait to have been called multiple times
|
||||
assert wrapped_wait.call_count > 3
|
||||
else:
|
||||
# writer writes two messages in sync with the reader
|
||||
dist.barrier()
|
||||
# sleep delays the send to ensure reader enters the read loop
|
||||
time.sleep(0.1)
|
||||
message_queue.enqueue(message1)
|
||||
|
||||
dist.barrier()
|
||||
time.sleep(0.1)
|
||||
message_queue.enqueue(message2)
|
||||
|
||||
message_queue.shutdown()
|
||||
assert message_queue.shutting_down
|
||||
print(f"torch distributed passed the test! Rank {rank}")
|
||||
|
||||
|
||||
def test_message_queue_idle_wake():
|
||||
distributed_run(worker_fn_test_idle_to_busy, 4)
|
||||
|
||||
|
||||
@worker_fn_wrapper
|
||||
def worker_fn_test_busy_to_idle():
|
||||
rank = dist.get_rank()
|
||||
writer_rank = 2
|
||||
message_queue = MessageQueue.create_from_process_group(
|
||||
dist.group.WORLD, 40 * 1024, 2, writer_rank
|
||||
)
|
||||
|
||||
message1 = 12345
|
||||
message2 = list(range(3))
|
||||
with mock.patch.object(
|
||||
message_queue._spin_condition, "wait", wraps=message_queue._spin_condition.wait
|
||||
) as wrapped_wait:
|
||||
if not message_queue._is_writer:
|
||||
# Put into busy mode
|
||||
message_queue._spin_condition.busy_loop_s = 9999
|
||||
|
||||
# sync with the writer and wait for message1
|
||||
dist.barrier()
|
||||
recv_message = message_queue.dequeue(timeout=1)
|
||||
assert recv_message == message1
|
||||
# in busy mode, we expect wait to have been called many times
|
||||
assert wrapped_wait.call_count > 1
|
||||
|
||||
# simulate busy loop ending
|
||||
message_queue._spin_condition.busy_loop_s = 0
|
||||
# ensure we enter idle mode, then record call count
|
||||
with pytest.raises(TimeoutError):
|
||||
message_queue.dequeue(timeout=0.01)
|
||||
call_count = wrapped_wait.call_count
|
||||
|
||||
# sync with the writer and wait for message2
|
||||
dist.barrier()
|
||||
recv_message = message_queue.dequeue(timeout=1)
|
||||
assert recv_message == message2
|
||||
|
||||
# call to wait after idle should only happen once
|
||||
assert wrapped_wait.call_count == call_count + 1
|
||||
else:
|
||||
# writer writes two messages in sync with the reader
|
||||
dist.barrier()
|
||||
# sleep delays the send to ensure reader enters the read loop
|
||||
time.sleep(0.1)
|
||||
message_queue.enqueue(message1)
|
||||
|
||||
dist.barrier()
|
||||
time.sleep(0.1)
|
||||
message_queue.enqueue(message2)
|
||||
|
||||
message_queue.shutdown()
|
||||
assert message_queue.shutting_down
|
||||
print(f"torch distributed passed the test! Rank {rank}")
|
||||
|
||||
|
||||
def test_message_queue_busy_to_idle():
|
||||
distributed_run(worker_fn_test_busy_to_idle, 4)
|
||||
|
||||
|
||||
def test_warning_logs(caplog_vllm):
|
||||
"""
|
||||
Test that warning logs are emitted at VLLM_RINGBUFFER_WARNING_INTERVAL intervals
|
||||
when indefinite=False, and are not emitted when indefinite=True.
|
||||
"""
|
||||
|
||||
# Patch the warning log interval to every 1 ms during reads
|
||||
with mock.patch(
|
||||
"vllm.distributed.device_communicators.shm_broadcast.VLLM_RINGBUFFER_WARNING_INTERVAL",
|
||||
new=0.001, # 1 ms
|
||||
):
|
||||
writer = MessageQueue(
|
||||
n_reader=1,
|
||||
n_local_reader=1,
|
||||
max_chunk_bytes=1024 * 1024, # 1MB chunks
|
||||
max_chunks=10,
|
||||
)
|
||||
reader = MessageQueue.create_from_handle(writer.export_handle(), rank=0)
|
||||
writer.wait_until_ready()
|
||||
reader.wait_until_ready()
|
||||
|
||||
# We should have at least one warning log here
|
||||
# "0 seconds" expected due to rounding of 1ms test interval
|
||||
with pytest.raises(TimeoutError):
|
||||
reader.dequeue(timeout=0.01, indefinite=False)
|
||||
assert any(
|
||||
"No available shared memory broadcast block found in 0 seconds"
|
||||
in record.message
|
||||
for record in caplog_vllm.records
|
||||
)
|
||||
caplog_vllm.clear()
|
||||
|
||||
# We should have no warnings this time
|
||||
with pytest.raises(TimeoutError):
|
||||
reader.dequeue(timeout=0.01, indefinite=True)
|
||||
assert all(
|
||||
"No available shared memory broadcast block found in 0 seconds"
|
||||
not in record.message
|
||||
for record in caplog_vllm.records
|
||||
)
|
||||
|
||||
# Clean up when done
|
||||
writer.shutdown()
|
||||
reader.shutdown()
|
||||
|
||||
Reference in New Issue
Block a user