[V1] Support MP Executor for multi node distributed inference (#23691)
Signed-off-by: Lu Fang <fanglu@fb.com>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Signed-off-by: Lucia Fang <fanglu@fb.com>
Signed-off-by: Lucia Fang <116399278+luccafong@users.noreply.github.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
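The heart of the change: each node inside a data-parallel rank now spawns only its local workers, with the leader node (node_rank_within_dp == 0) owning the scheduler-output broadcast queue and collecting per-rank response queues. A minimal standalone sketch of the rank layout the new executor computes (the helper name and the assert-based example are illustrative, not code from this patch):

# Illustrative sketch only: mirrors the global_start_rank / global_rank
# arithmetic used in MultiprocExecutor._init_executor in the diff below.
def node_worker_ranks(
    world_size: int, nnodes_within_dp: int, node_rank_within_dp: int
) -> list[int]:
    # Mirrors the executor's assertion that ranks split evenly across nodes.
    assert world_size % nnodes_within_dp == 0
    local_world_size = world_size // nnodes_within_dp
    global_start_rank = local_world_size * node_rank_within_dp
    return [global_start_rank + local_rank for local_rank in range(local_world_size)]

# e.g. tensor_parallel_size=4, pipeline_parallel_size=2 across 2 nodes:
# node 0 hosts global ranks 0-3, node 1 hosts global ranks 4-7.
assert node_worker_ranks(8, 2, 0) == [0, 1, 2, 3]
assert node_worker_ranks(8, 2, 1) == [4, 5, 6, 7]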
@@ -10,7 +10,7 @@ import time
 import traceback
 import weakref
 from collections import deque
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
 from concurrent.futures import Future, InvalidStateError
 from contextlib import suppress
 from dataclasses import dataclass
@@ -34,6 +34,7 @@ from vllm.distributed.parallel_state import (
     get_dcp_group,
     get_dp_group,
     get_ep_group,
+    get_inner_dp_world_group,
     get_pp_group,
     get_tp_group,
 )
@@ -90,6 +91,10 @@ class FutureWrapper(Future):
 class MultiprocExecutor(Executor):
     supports_pp: bool = True
 
+    def __init__(self, vllm_config: VllmConfig, monitor_workers: bool = True):
+        self.monitor_workers = monitor_workers
+        super().__init__(vllm_config)
+
     def _init_executor(self) -> None:
         # Call self.shutdown at exit to clean up
         # and ensure workers will be terminated.
@@ -99,6 +104,12 @@ class MultiprocExecutor(Executor):
         self.failure_callback: FailureCallback | None = None
 
         self.world_size = self.parallel_config.world_size
+        assert self.world_size % self.parallel_config.nnodes_within_dp == 0, (
+            f"global world_size ({self.parallel_config.world_size}) must be "
+            f"divisible by nnodes_within_dp "
+            f"({self.parallel_config.nnodes_within_dp}). "
+        )
+        self.local_world_size = self.parallel_config.local_world_size
         tensor_parallel_size = self.parallel_config.tensor_parallel_size
         pp_parallel_size = self.parallel_config.pipeline_parallel_size
         assert self.world_size == tensor_parallel_size * pp_parallel_size, (
@@ -116,27 +127,37 @@ class MultiprocExecutor(Executor):
         distributed_init_method = get_distributed_init_method(
             get_loopback_ip(), get_open_port()
         )
 
-        # Initialize worker and set up message queues for SchedulerOutputs
-        # and ModelRunnerOutputs
-        max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
-        self.rpc_broadcast_mq = MessageQueue(
-            self.world_size, self.world_size, max_chunk_bytes=max_chunk_bytes
-        )
-        scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
+        self.rpc_broadcast_mq: MessageQueue | None = None
+        scheduler_output_handle: Handle | None = None
+        if self.parallel_config.node_rank_within_dp == 0:
+            # For leader node within each dp rank,
+            # each dp will have its own leader multiproc executor.
+            max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
+            self.rpc_broadcast_mq = MessageQueue(
+                self.world_size,
+                self.local_world_size,
+                max_chunk_bytes=max_chunk_bytes,
+                connect_ip=self.parallel_config.master_addr,
+            )
+            scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
 
         # Create workers
         context = get_mp_context()
         shared_worker_lock = context.Lock()
         unready_workers: list[UnreadyWorkerProcHandle] = []
         success = False
         try:
-            for rank in range(self.world_size):
+            global_start_rank = (
+                self.local_world_size * self.parallel_config.node_rank_within_dp
+            )
+            for local_rank in range(self.local_world_size):
+                global_rank = global_start_rank + local_rank
                 unready_workers.append(
                     WorkerProc.make_worker_process(
                         vllm_config=self.vllm_config,
-                        local_rank=rank,
-                        rank=rank,
+                        local_rank=local_rank,
+                        rank=global_rank,
                         distributed_init_method=distributed_init_method,
                         input_shm_handle=scheduler_output_handle,
                         shared_worker_lock=shared_worker_lock,
@@ -145,15 +166,38 @@ class MultiprocExecutor(Executor):
 
             # Workers must be created before wait_for_ready to avoid
             # deadlock, since worker.init_device() does a device sync.
+
+            # Wait for all local workers to be ready.
             self.workers = WorkerProc.wait_for_ready(unready_workers)
 
+            # Start background thread to monitor worker health if not in headless mode.
+            if self.monitor_workers:
+                self.start_worker_monitor()
+
+            self.response_mqs = []
+            # Only leader node have remote response mqs
+            if self.parallel_config.node_rank_within_dp == 0:
+                for rank in range(self.world_size):
+                    if rank < self.local_world_size:
+                        local_message_queue = self.workers[rank].worker_response_mq
+                        assert local_message_queue is not None
+                        self.response_mqs.append(local_message_queue)
+                    else:
+                        remote_message_queue = self.workers[0].peer_worker_response_mqs[
+                            rank
+                        ]
+                        assert remote_message_queue is not None
+                        self.response_mqs.append(remote_message_queue)
+
             # Ensure message queues are ready. Will deadlock if re-ordered
             # Must be kept consistent with the WorkerProc.
-            self.rpc_broadcast_mq.wait_until_ready()
-            for w in self.workers:
-                w.worker_response_mq.wait_until_ready()
 
-            self.start_worker_monitor()
+            # Wait for all input mqs to be ready.
+            if self.rpc_broadcast_mq is not None:
+                self.rpc_broadcast_mq.wait_until_ready()
+            # Wait for all remote response mqs to be ready.
+            for response_mq in self.response_mqs:
+                response_mq.wait_until_ready()
             success = True
         finally:
             if not success:
@@ -168,7 +212,7 @@ class MultiprocExecutor(Executor):
 
         self.output_rank = self._get_output_rank()
 
-    def start_worker_monitor(self):
+    def start_worker_monitor(self, inline=False) -> None:
         workers = self.workers
         self_ref = weakref.ref(self)
 
@@ -192,9 +236,13 @@ class MultiprocExecutor(Executor):
                 _self.failure_callback = None
                 callback()
 
-        Thread(
-            target=monitor_workers, daemon=True, name="MultiprocWorkerMonitor"
-        ).start()
+        if not inline:
+            Thread(
+                target=monitor_workers, daemon=True, name="MultiprocWorkerMonitor"
+            ).start()
+            return
+
+        monitor_workers()
 
     def register_failure_callback(self, callback: FailureCallback):
         if self.is_failed:
@@ -247,7 +295,9 @@ class MultiprocExecutor(Executor):
     ) -> Any | list[Any] | Future[Any | list[Any]]:
         """Returns single result if unique_reply_rank and/or kv_output_aggregator
         is provided, otherwise list."""
-
+        assert self.rpc_broadcast_mq is not None, (
+            "collective_rpc should not be called on follower node"
+        )
         if self.is_failed:
             raise RuntimeError("Executor failed.")
 
@@ -269,20 +319,20 @@ class MultiprocExecutor(Executor):
         send_method = cloudpickle.dumps(method, protocol=pickle.HIGHEST_PROTOCOL)
         self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, output_rank))
 
-        workers = (
-            (self.workers[output_rank],) if output_rank is not None else self.workers
-        )
+        response_mqs: Sequence[MessageQueue] = self.response_mqs
+        if output_rank is not None:
+            response_mqs = (response_mqs[output_rank],)
 
         shutdown_event = self.shutdown_event
 
         def get_response():
             responses = []
-            for w in workers:
+            for mq in response_mqs:
                 dequeue_timeout = (
                     None if deadline is None else (deadline - time.monotonic())
                 )
                 try:
-                    status, result = w.worker_response_mq.dequeue(
+                    status, result = mq.dequeue(
                         timeout=dequeue_timeout, cancel=shutdown_event
                     )
                 except TimeoutError as e:
@@ -391,17 +441,26 @@ class UnreadyWorkerProcHandle:
 class WorkerProcHandle:
     proc: BaseProcess
     rank: int
-    worker_response_mq: MessageQueue  # The worker process writes to this MQ
+    # The worker process writes to this MQ in single-node mode
+    worker_response_mq: MessageQueue | None
+    # This is only non empty on driver node,
+    # the peer worker process i writes to MQ
+    # `peer_worker_response_mqs[i]`
+    peer_worker_response_mqs: list[MessageQueue | None]
     death_writer: Connection | None = None
 
     @classmethod
     def from_unready_handle(
-        cls, unready_handle: UnreadyWorkerProcHandle, worker_response_mq: MessageQueue
+        cls,
+        unready_handle: UnreadyWorkerProcHandle,
+        worker_response_mq: MessageQueue | None,
+        peer_worker_response_mqs: list[MessageQueue | None],
     ) -> "WorkerProcHandle":
         return cls(
             proc=unready_handle.proc,
             rank=unready_handle.rank,
             worker_response_mq=worker_response_mq,
+            peer_worker_response_mqs=peer_worker_response_mqs,
             death_writer=unready_handle.death_writer,
         )
 
@@ -411,6 +470,38 @@ class WorkerProc:
 
     READY_STR = "READY"
 
+    def _init_message_queues(
+        self, input_shm_handle: Handle, vllm_config: VllmConfig
+    ) -> None:
+        if vllm_config.parallel_config.nnodes_within_dp == 1:
+            # Initialize MessageQueue for receiving SchedulerOutput
+            self.rpc_broadcast_mq = MessageQueue.create_from_handle(
+                input_shm_handle, self.worker.rank
+            )
+
+            # Initializes a message queue for sending the model output
+            self.worker_response_mq: MessageQueue = MessageQueue(1, 1)
+            self.peer_response_handles = []
+        else:
+            # Initialize remote MessageQueue for receiving SchedulerOutput across nodes
+            self.rpc_broadcast_mq = get_inner_dp_world_group().create_mq_broadcaster(
+                external_writer_handle=input_shm_handle,
+                # Since there is external_writer_handle from executor proc,
+                # where the ready signal from actual writer is sent out of the
+                # create_mq_broadcaster method and after this setup, we make it
+                # non blocking. The handshake will be triggered when
+                # worker.rpc_broadcast_mq.wait_until_ready() is called
+                blocking=False,
+            )
+            # Initializes remote message queue for sending the model output to the
+            # driver worker, exposing peer_response_handles for driver worker
+            # that include handles for all ranks
+            self.worker_response_mq, self.peer_response_handles = (
+                get_inner_dp_world_group().create_single_reader_mq_broadcasters(
+                    reader_rank_in_group=0
+                )
+            )
+
     def __init__(
         self,
         vllm_config: VllmConfig,
@@ -421,13 +512,15 @@ class WorkerProc:
         shared_worker_lock: LockType,
     ):
         self.rank = rank
-        wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
+        wrapper = WorkerWrapperBase(
+            vllm_config=vllm_config, rpc_rank=local_rank, global_rank=rank
+        )
         # TODO: move `init_worker` to executor level as a collective rpc call
         all_kwargs: list[dict] = [
             {} for _ in range(vllm_config.parallel_config.world_size)
         ]
         is_driver_worker = rank % vllm_config.parallel_config.tensor_parallel_size == 0
-        all_kwargs[rank] = {
+        all_kwargs[local_rank] = {
            "vllm_config": vllm_config,
            "local_rank": local_rank,
            "rank": rank,
@@ -438,14 +531,6 @@ class WorkerProc:
         wrapper.init_worker(all_kwargs)
         self.worker = wrapper
 
-        # Initialize MessageQueue for receiving SchedulerOutput
-        self.rpc_broadcast_mq = MessageQueue.create_from_handle(
-            input_shm_handle, self.worker.rank
-        )
-
-        # Initializes a message queue for sending the model output
-        self.worker_response_mq = MessageQueue(1, 1)
-
         scheduler_config = vllm_config.scheduler_config
         self.use_async_scheduling = scheduler_config.async_scheduling
         if self.use_async_scheduling:
@@ -466,6 +551,7 @@ class WorkerProc:
             )
 
         # Load model
+        self._init_message_queues(input_shm_handle, vllm_config)
         self.worker.load_model()
 
         # Enable environment variable cache (e.g. assume no more
@@ -512,6 +598,27 @@ class WorkerProc:
         # death_reader in child will get EOFError
         return UnreadyWorkerProcHandle(proc, rank, reader, death_writer)
 
+    @staticmethod
+    def wait_for_response_handle_ready(
+        handles: dict[str, Any], proc_handle: UnreadyWorkerProcHandle
+    ) -> WorkerProcHandle:
+        response_handle = handles["handle"]
+        worker_response_mq: MessageQueue | None = None
+        if len(response_handle.local_reader_ranks) > 0:
+            worker_response_mq = MessageQueue.create_from_handle(response_handle, 0)
+        peer_response_handles = handles["peer_response_handles"]
+        peer_worker_response_mqs = [
+            MessageQueue.create_from_handle(handle, -1)
+            if handle.remote_subscribe_addr is not None
+            else None
+            for handle in peer_response_handles
+        ]
+        return WorkerProcHandle.from_unready_handle(
+            proc_handle,
+            worker_response_mq,
+            peer_worker_response_mqs=peer_worker_response_mqs,
+        )
+
     @staticmethod
     def wait_for_ready(
         unready_proc_handles: list[UnreadyWorkerProcHandle],
@@ -537,16 +644,10 @@ class WorkerProc:
                     if response["status"] != "READY":
                         raise e
 
-                    # Extract the message queue handle.
-                    worker_response_mq = MessageQueue.create_from_handle(
-                        response["handle"], 0
+                    idx = unready_proc_handle.rank % len(ready_proc_handles)
+                    ready_proc_handles[idx] = WorkerProc.wait_for_response_handle_ready(
+                        response, unready_proc_handle
                     )
-                    ready_proc_handles[unready_proc_handle.rank] = (
-                        WorkerProcHandle.from_unready_handle(
-                            unready_proc_handle, worker_response_mq
-                        )
-                    )
-
                 except EOFError:
                     e.__suppress_context__ = True
                     raise e from None
@@ -618,12 +719,14 @@ class WorkerProc:
                 {
                     "status": WorkerProc.READY_STR,
                     "handle": worker.worker_response_mq.export_handle(),
+                    "peer_response_handles": worker.peer_response_handles,
                 }
             )
 
             # Ensure message queues are ready. Will deadlock if re-ordered.
             # Must be kept consistent with the Executor
-            worker.rpc_broadcast_mq.wait_until_ready()
+            if worker.rpc_broadcast_mq is not None:
+                worker.rpc_broadcast_mq.wait_until_ready()
             worker.worker_response_mq.wait_until_ready()
             ready_writer.close()
             ready_writer = None
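For orientation after reading the diff: on the leader node the executor ends up with one response queue per global rank, taking local ranks' queues directly from the local worker handles and remote ranks' queues from the rank-0 worker's peer handles. A simplified standalone sketch of that selection logic (function name and simplified types are assumptions, not code from this patch):

# Illustrative sketch only: condenses the response-queue loop from
# MultiprocExecutor._init_executor on the leader node (node_rank_within_dp == 0).
def collect_response_mqs(workers, world_size: int, local_world_size: int) -> list:
    response_mqs = []
    for rank in range(world_size):
        if rank < local_world_size:
            # Local worker: it writes to its own response queue.
            mq = workers[rank].worker_response_mq
        else:
            # Remote worker: reachable through the rank-0 worker's peer queues.
            mq = workers[0].peer_worker_response_mqs[rank]
        assert mq is not None, f"missing response queue for global rank {rank}"
        response_mqs.append(mq)
    return response_mqs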