[V1] Support MP Executor for multi node distributed inference (#23691)

Signed-off-by: Lu Fang <fanglu@fb.com>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Signed-off-by: Lucia Fang <fanglu@fb.com>
Signed-off-by: Lucia Fang <116399278+luccafong@users.noreply.github.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Lucia Fang authored on 2025-11-16 01:01:21 -08:00 (committed by GitHub)
parent a55b64635c
commit b316ac6589
10 changed files with 930 additions and 82 deletions


@@ -10,7 +10,7 @@ import time
import traceback
import weakref
from collections import deque
from collections.abc import Callable
from collections.abc import Callable, Sequence
from concurrent.futures import Future, InvalidStateError
from contextlib import suppress
from dataclasses import dataclass
@@ -34,6 +34,7 @@ from vllm.distributed.parallel_state import (
get_dcp_group,
get_dp_group,
get_ep_group,
get_inner_dp_world_group,
get_pp_group,
get_tp_group,
)
@@ -90,6 +91,10 @@ class FutureWrapper(Future):
class MultiprocExecutor(Executor):
supports_pp: bool = True
def __init__(self, vllm_config: VllmConfig, monitor_workers: bool = True):
self.monitor_workers = monitor_workers
super().__init__(vllm_config)
def _init_executor(self) -> None:
# Call self.shutdown at exit to clean up
# and ensure workers will be terminated.
@@ -99,6 +104,12 @@ class MultiprocExecutor(Executor):
self.failure_callback: FailureCallback | None = None
self.world_size = self.parallel_config.world_size
assert self.world_size % self.parallel_config.nnodes_within_dp == 0, (
f"global world_size ({self.parallel_config.world_size}) must be "
f"divisible by nnodes_within_dp "
f"({self.parallel_config.nnodes_within_dp}). "
)
self.local_world_size = self.parallel_config.local_world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size
pp_parallel_size = self.parallel_config.pipeline_parallel_size
assert self.world_size == tensor_parallel_size * pp_parallel_size, (
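
For orientation, the two assertions pin down a simple size relationship. The numbers below are illustrative only (TP=8, PP=1, two nodes within the DP rank), not taken from the commit:

# Illustrative numbers only: TP=8, PP=1, two nodes within the DP rank.
tensor_parallel_size = 8
pipeline_parallel_size = 1
nnodes_within_dp = 2

world_size = tensor_parallel_size * pipeline_parallel_size   # 8
assert world_size % nnodes_within_dp == 0                     # satisfied: 8 % 2 == 0
local_world_size = world_size // nnodes_within_dp             # 4 workers per node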
@@ -116,27 +127,37 @@ class MultiprocExecutor(Executor):
distributed_init_method = get_distributed_init_method(
get_loopback_ip(), get_open_port()
)
self.rpc_broadcast_mq: MessageQueue | None = None
scheduler_output_handle: Handle | None = None
# Initialize worker and set up message queues for SchedulerOutputs
# and ModelRunnerOutputs
max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
self.rpc_broadcast_mq = MessageQueue(
self.world_size, self.world_size, max_chunk_bytes=max_chunk_bytes
)
scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
if self.parallel_config.node_rank_within_dp == 0:
# Each DP rank has its own leader multiproc executor,
# running on the leader node within that DP rank.
max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
self.rpc_broadcast_mq = MessageQueue(
self.world_size,
self.local_world_size,
max_chunk_bytes=max_chunk_bytes,
connect_ip=self.parallel_config.master_addr,
)
scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
# Create workers
context = get_mp_context()
shared_worker_lock = context.Lock()
unready_workers: list[UnreadyWorkerProcHandle] = []
success = False
try:
for rank in range(self.world_size):
global_start_rank = (
self.local_world_size * self.parallel_config.node_rank_within_dp
)
for local_rank in range(self.local_world_size):
global_rank = global_start_rank + local_rank
unready_workers.append(
WorkerProc.make_worker_process(
vllm_config=self.vllm_config,
local_rank=rank,
rank=rank,
local_rank=local_rank,
rank=global_rank,
distributed_init_method=distributed_init_method,
input_shm_handle=scheduler_output_handle,
shared_worker_lock=shared_worker_lock,
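
A minimal sketch of the local-to-global rank mapping applied by the spawn loop above, continuing the illustrative two-node example:

# Continuing the illustrative example: 2 nodes, 4 workers per node.
local_world_size, nnodes_within_dp = 4, 2
for node_rank_within_dp in range(nnodes_within_dp):
    global_start_rank = local_world_size * node_rank_within_dp
    for local_rank in range(local_world_size):
        global_rank = global_start_rank + local_rank
        print(f"node {node_rank_within_dp}: local {local_rank} -> global {global_rank}")
# node 0 spawns global ranks 0-3, node 1 spawns global ranks 4-7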
@@ -145,15 +166,38 @@ class MultiprocExecutor(Executor):
# Workers must be created before wait_for_ready to avoid
# deadlock, since worker.init_device() does a device sync.
# Wait for all local workers to be ready.
self.workers = WorkerProc.wait_for_ready(unready_workers)
# Start background thread to monitor worker health if not in headless mode.
if self.monitor_workers:
self.start_worker_monitor()
self.response_mqs = []
# Only the leader node has remote response mqs
if self.parallel_config.node_rank_within_dp == 0:
for rank in range(self.world_size):
if rank < self.local_world_size:
local_message_queue = self.workers[rank].worker_response_mq
assert local_message_queue is not None
self.response_mqs.append(local_message_queue)
else:
remote_message_queue = self.workers[0].peer_worker_response_mqs[
rank
]
assert remote_message_queue is not None
self.response_mqs.append(remote_message_queue)
# Ensure message queues are ready. Will deadlock if re-ordered
# Must be kept consistent with the WorkerProc.
self.rpc_broadcast_mq.wait_until_ready()
for w in self.workers:
w.worker_response_mq.wait_until_ready()
self.start_worker_monitor()
# Wait for all input mqs to be ready.
if self.rpc_broadcast_mq is not None:
self.rpc_broadcast_mq.wait_until_ready()
# Wait for all remote response mqs to be ready.
for response_mq in self.response_mqs:
response_mq.wait_until_ready()
success = True
finally:
if not success:
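
Summarizing the leader-node bookkeeping above (an informal sketch, not code from the commit; strings stand in for the real MessageQueue objects): the leader ends up with one response queue per global rank, local ranks served by the workers' own queues and remote ranks by the peer queues exposed through worker 0.

# Strings stand in for real MessageQueue objects (illustrative only).
local_world_size, world_size = 4, 8
local_mqs = [f"local_mq[{r}]" for r in range(local_world_size)]
peer_mqs = {r: f"peer_mq[{r}]" for r in range(local_world_size, world_size)}

response_mqs = [
    local_mqs[rank] if rank < local_world_size else peer_mqs[rank]
    for rank in range(world_size)
]
# Follower nodes (node_rank_within_dp > 0) keep response_mqs empty.
print(response_mqs)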
@@ -168,7 +212,7 @@ class MultiprocExecutor(Executor):
self.output_rank = self._get_output_rank()
def start_worker_monitor(self):
def start_worker_monitor(self, inline: bool = False) -> None:
workers = self.workers
self_ref = weakref.ref(self)
@@ -192,9 +236,13 @@ class MultiprocExecutor(Executor):
_self.failure_callback = None
callback()
Thread(
target=monitor_workers, daemon=True, name="MultiprocWorkerMonitor"
).start()
if not inline:
Thread(
target=monitor_workers, daemon=True, name="MultiprocWorkerMonitor"
).start()
return
monitor_workers()
def register_failure_callback(self, callback: FailureCallback):
if self.is_failed:
@@ -247,7 +295,9 @@ class MultiprocExecutor(Executor):
) -> Any | list[Any] | Future[Any | list[Any]]:
"""Returns single result if unique_reply_rank and/or kv_output_aggregator
is provided, otherwise list."""
assert self.rpc_broadcast_mq is not None, (
"collective_rpc should not be called on follower node"
)
if self.is_failed:
raise RuntimeError("Executor failed.")
@@ -269,20 +319,20 @@ class MultiprocExecutor(Executor):
send_method = cloudpickle.dumps(method, protocol=pickle.HIGHEST_PROTOCOL)
self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, output_rank))
workers = (
(self.workers[output_rank],) if output_rank is not None else self.workers
)
response_mqs: Sequence[MessageQueue] = self.response_mqs
if output_rank is not None:
response_mqs = (response_mqs[output_rank],)
shutdown_event = self.shutdown_event
def get_response():
responses = []
for w in workers:
for mq in response_mqs:
dequeue_timeout = (
None if deadline is None else (deadline - time.monotonic())
)
try:
status, result = w.worker_response_mq.dequeue(
status, result = mq.dequeue(
timeout=dequeue_timeout, cancel=shutdown_event
)
except TimeoutError as e:
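
The per-queue timeout is recomputed from a single absolute deadline so the total wait across all response queues stays bounded. A standalone sketch of that pattern, using plain queue.Queue rather than vLLM's MessageQueue:

import queue
import time

def dequeue_all(queues: list[queue.Queue], deadline: float | None = None) -> list:
    """Drain one item from each queue without exceeding an absolute deadline."""
    responses = []
    for q in queues:
        # Recompute the remaining budget before each blocking dequeue.
        timeout = None if deadline is None else deadline - time.monotonic()
        if timeout is not None and timeout <= 0:
            raise TimeoutError("RPC call timed out")
        responses.append(q.get(timeout=timeout))
    return responses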
@@ -391,17 +441,26 @@ class UnreadyWorkerProcHandle:
class WorkerProcHandle:
proc: BaseProcess
rank: int
worker_response_mq: MessageQueue # The worker process writes to this MQ
# The worker process writes to this MQ in single-node mode
worker_response_mq: MessageQueue | None
# Only non-empty on the driver node; peer worker process i
# writes to MQ `peer_worker_response_mqs[i]`
peer_worker_response_mqs: list[MessageQueue | None]
death_writer: Connection | None = None
@classmethod
def from_unready_handle(
cls, unready_handle: UnreadyWorkerProcHandle, worker_response_mq: MessageQueue
cls,
unready_handle: UnreadyWorkerProcHandle,
worker_response_mq: MessageQueue | None,
peer_worker_response_mqs: list[MessageQueue | None],
) -> "WorkerProcHandle":
return cls(
proc=unready_handle.proc,
rank=unready_handle.rank,
worker_response_mq=worker_response_mq,
peer_worker_response_mqs=peer_worker_response_mqs,
death_writer=unready_handle.death_writer,
)
@@ -411,6 +470,38 @@ class WorkerProc:
READY_STR = "READY"
def _init_message_queues(
self, input_shm_handle: Handle, vllm_config: VllmConfig
) -> None:
if vllm_config.parallel_config.nnodes_within_dp == 1:
# Initialize MessageQueue for receiving SchedulerOutput
self.rpc_broadcast_mq = MessageQueue.create_from_handle(
input_shm_handle, self.worker.rank
)
# Initializes a message queue for sending the model output
self.worker_response_mq: MessageQueue = MessageQueue(1, 1)
self.peer_response_handles = []
else:
# Initialize remote MessageQueue for receiving SchedulerOutput across nodes
self.rpc_broadcast_mq = get_inner_dp_world_group().create_mq_broadcaster(
external_writer_handle=input_shm_handle,
# The external_writer_handle comes from the executor process, so the
# actual writer's ready signal is sent outside create_mq_broadcaster,
# after this setup. Make the call non-blocking; the handshake is
# triggered later when worker.rpc_broadcast_mq.wait_until_ready()
# is called.
blocking=False,
)
# Initializes a remote message queue for sending the model output to
# the driver worker, and exposes peer_response_handles (one handle
# per rank) to the driver worker.
self.worker_response_mq, self.peer_response_handles = (
get_inner_dp_world_group().create_single_reader_mq_broadcasters(
reader_rank_in_group=0
)
)
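
One way to picture the two branches of _init_message_queues (an informal summary based on the comments above, not code from the commit):

# Informal summary of the two branches (illustrative only).
def mq_topology(nnodes_within_dp: int) -> dict[str, str]:
    if nnodes_within_dp == 1:
        return {
            "rpc_broadcast_mq": "attach to the executor's shm handle (create_from_handle)",
            "worker_response_mq": "private 1x1 MessageQueue, read by the executor",
            "peer_response_handles": "empty list",
        }
    return {
        "rpc_broadcast_mq": "cross-node broadcaster from the inner-DP group",
        "worker_response_mq": "single-reader queue, read by the driver worker (rank 0)",
        "peer_response_handles": "handles for all ranks, sent back to the driver node",
    }

print(mq_topology(1))
print(mq_topology(2))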
def __init__(
self,
vllm_config: VllmConfig,
@@ -421,13 +512,15 @@ class WorkerProc:
shared_worker_lock: LockType,
):
self.rank = rank
wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
wrapper = WorkerWrapperBase(
vllm_config=vllm_config, rpc_rank=local_rank, global_rank=rank
)
# TODO: move `init_worker` to executor level as a collective rpc call
all_kwargs: list[dict] = [
{} for _ in range(vllm_config.parallel_config.world_size)
]
is_driver_worker = rank % vllm_config.parallel_config.tensor_parallel_size == 0
all_kwargs[rank] = {
all_kwargs[local_rank] = {
"vllm_config": vllm_config,
"local_rank": local_rank,
"rank": rank,
@@ -438,14 +531,6 @@ class WorkerProc:
wrapper.init_worker(all_kwargs)
self.worker = wrapper
# Initialize MessageQueue for receiving SchedulerOutput
self.rpc_broadcast_mq = MessageQueue.create_from_handle(
input_shm_handle, self.worker.rank
)
# Initializes a message queue for sending the model output
self.worker_response_mq = MessageQueue(1, 1)
scheduler_config = vllm_config.scheduler_config
self.use_async_scheduling = scheduler_config.async_scheduling
if self.use_async_scheduling:
@@ -466,6 +551,7 @@ class WorkerProc:
)
# Load model
self._init_message_queues(input_shm_handle, vllm_config)
self.worker.load_model()
# Enable environment variable cache (e.g. assume no more
@@ -512,6 +598,27 @@ class WorkerProc:
# death_reader in child will get EOFError
return UnreadyWorkerProcHandle(proc, rank, reader, death_writer)
@staticmethod
def wait_for_response_handle_ready(
handles: dict[str, Any], proc_handle: UnreadyWorkerProcHandle
) -> WorkerProcHandle:
response_handle = handles["handle"]
worker_response_mq: MessageQueue | None = None
if len(response_handle.local_reader_ranks) > 0:
worker_response_mq = MessageQueue.create_from_handle(response_handle, 0)
peer_response_handles = handles["peer_response_handles"]
peer_worker_response_mqs = [
MessageQueue.create_from_handle(handle, -1)
if handle.remote_subscribe_addr is not None
else None
for handle in peer_response_handles
]
return WorkerProcHandle.from_unready_handle(
proc_handle,
worker_response_mq,
peer_worker_response_mqs=peer_worker_response_mqs,
)
@staticmethod
def wait_for_ready(
unready_proc_handles: list[UnreadyWorkerProcHandle],
@@ -537,16 +644,10 @@ class WorkerProc:
if response["status"] != "READY":
raise e
# Extract the message queue handle.
worker_response_mq = MessageQueue.create_from_handle(
response["handle"], 0
idx = unready_proc_handle.rank % len(ready_proc_handles)
ready_proc_handles[idx] = WorkerProc.wait_for_response_handle_ready(
response, unready_proc_handle
)
ready_proc_handles[unready_proc_handle.rank] = (
WorkerProcHandle.from_unready_handle(
unready_proc_handle, worker_response_mq
)
)
except EOFError:
e.__suppress_context__ = True
raise e from None
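
Because ready_proc_handles has one slot per local worker while unready_proc_handle.rank is the global rank, the modulo maps global ranks back to node-local slots. A quick check with the illustrative sizes used earlier:

# ready_proc_handles has local_world_size slots; ranks reported back are global.
local_world_size = 4
for global_rank in (4, 5, 6, 7):      # ranks spawned on node 1 in the example above
    idx = global_rank % local_world_size
    print(global_rank, "->", idx)     # 4 -> 0, 5 -> 1, 6 -> 2, 7 -> 3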
@@ -618,12 +719,14 @@ class WorkerProc:
{
"status": WorkerProc.READY_STR,
"handle": worker.worker_response_mq.export_handle(),
"peer_response_handles": worker.peer_response_handles,
}
)
# Ensure message queues are ready. Will deadlock if re-ordered.
# Must be kept consistent with the Executor
worker.rpc_broadcast_mq.wait_until_ready()
if worker.rpc_broadcast_mq is not None:
worker.rpc_broadcast_mq.wait_until_ready()
worker.worker_response_mq.wait_until_ready()
ready_writer.close()
ready_writer = None