[V1][Frontend] Improve Shutdown And Logs (#11737)

Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Andrew Feldman <afeldman@neuralmagic.com>
Co-authored-by: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Robert Shaw
2025-04-16 22:48:34 -04:00
committed by GitHub
parent 3c776dcefb
commit 2b05b8ce69
16 changed files with 1031 additions and 347 deletions

View File

@@ -1,21 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
import multiprocessing
import os
import pickle
import signal
import sys
import threading
import time
import traceback
import weakref
from concurrent.futures import Future
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from multiprocessing.connection import Connection
from multiprocessing.process import BaseProcess
from typing import Any, Callable, Optional, Union
from threading import Thread
from typing import Any, Callable, Optional, Union, cast
import cloudpickle
import psutil
import zmq
from vllm.config import VllmConfig
from vllm.distributed import (destroy_distributed_environment,
@@ -26,8 +28,9 @@ from vllm.executor.multiproc_worker_utils import (
_add_prefix, set_multiprocessing_worker_envs)
from vllm.logger import init_logger
from vllm.utils import (get_distributed_init_method, get_mp_context,
get_open_port, get_open_zmq_ipc_path, zmq_socket_ctx)
from vllm.v1.executor.abstract import Executor
get_open_port)
from vllm.v1.executor.abstract import Executor, FailureCallback
from vllm.v1.outputs import ModelRunnerOutput
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
@@ -35,6 +38,8 @@ logger = init_logger(__name__)
POLLING_TIMEOUT_MS = 5000
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
EXECUTE_MODEL_TIMEOUT_S = 30
class MultiprocExecutor(Executor):
@@ -42,19 +47,9 @@ class MultiprocExecutor(Executor):
# Call self.shutdown at exit to clean up
# and ensure workers will be terminated.
self._finalizer = weakref.finalize(self, self.shutdown)
# The child processes will send SIGUSR1 when unrecoverable
# errors happen.
def sigusr1_handler(signum, frame):
logger.fatal(
"MulitprocExecutor got fatal signal from worker processes, "
"shutting down. See stack trace above for root cause issue.")
# Propagate error up to parent process.
parent_process = psutil.Process().parent()
parent_process.send_signal(signal.SIGUSR1)
self.shutdown()
signal.signal(signal.SIGUSR1, sigusr1_handler)
self.is_failed = False
self.shutdown_event = threading.Event()
self.failure_callback: Optional[FailureCallback] = None
self.world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size
@@ -78,28 +73,94 @@ class MultiprocExecutor(Executor):
scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
# Create workers
self.workers: list[WorkerProcHandle] = []
for rank in range(self.world_size):
worker = WorkerProc.make_worker_process(self.vllm_config, rank,
rank,
distributed_init_method,
scheduler_output_handle)
self.workers.append(worker)
unready_workers: list[UnreadyWorkerProcHandle] = []
success = False
try:
for rank in range(self.world_size):
unready_workers.append(
WorkerProc.make_worker_process(
vllm_config=self.vllm_config,
local_rank=rank,
rank=rank,
distributed_init_method=distributed_init_method,
input_shm_handle=scheduler_output_handle,
))
# Ensure message queues are ready. Will deadlock if re-ordered
# Must be kept consistent with the WorkerProc
self.rpc_broadcast_mq.wait_until_ready()
for w in self.workers:
w.worker_response_mq.wait_until_ready()
# Workers must be created before wait_for_ready to avoid
# deadlock, since worker.init_device() does a device sync.
self.workers = WorkerProc.wait_for_ready(unready_workers)
# Ensure message queues are ready. Will deadlock if re-ordered
# Must be kept consistent with the WorkerProc.
self.rpc_broadcast_mq.wait_until_ready()
for w in self.workers:
w.worker_response_mq.wait_until_ready()
self.start_worker_monitor()
success = True
finally:
if not success:
# Clean up the worker procs if there was a failure.
self._ensure_worker_termination(
[w.proc for w in unready_workers])
def start_worker_monitor(self):
# Spawn a daemon thread that watches all worker process sentinels and
# reacts once if any worker process exits unexpectedly.
workers = self.workers
# Hold only a weak reference to the executor so the monitor thread does
# not keep the executor alive past its normal finalization.
self_ref = weakref.ref(self)
# Monitors worker process liveness. If any die unexpectedly,
# logs an error, shuts down the executor and invokes the failure
# callback to inform the engine.
def monitor_workers():
# multiprocessing sentinels become ready when a process terminates.
sentinels = [h.proc.sentinel for h in workers]
died = multiprocessing.connection.wait(sentinels)
_self = self_ref()
# Executor already gone or deliberately shutting down: nothing to do.
if not _self or getattr(_self, 'shutting_down', False):
return
_self.is_failed = True
# Map the first triggered sentinel back to its worker's name for the log.
proc_name = next(h.proc.name for h in workers
if h.proc.sentinel == died[0])
logger.error(
"Worker proc %s died unexpectedly, "
"shutting down executor.", proc_name)
_self.shutdown()
# Fire the registered failure callback exactly once (cleared before call).
callback = _self.failure_callback
if callback is not None:
_self.failure_callback = None
callback()
# daemon=True so this thread never blocks interpreter exit.
Thread(target=monitor_workers,
daemon=True,
name="MultiprocWorkerMonitor").start()
def register_failure_callback(self, callback: FailureCallback):
# Register a callback to be invoked when a worker failure is detected.
# If a failure has already been recorded, invoke it immediately;
# otherwise store it for the worker monitor thread to call later.
if self.is_failed:
callback()
else:
self.failure_callback = callback
def execute_model(
self,
scheduler_output,
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
# Broadcast "execute_model" to all workers but collect the reply from
# rank 0 only (rank0_reply_only=True), bounded by EXECUTE_MODEL_TIMEOUT_S.
# collective_rpc returns a list; unpack the single rank-0 response.
(output, ) = self.collective_rpc("execute_model",
args=(scheduler_output, ),
rank0_reply_only=True,
timeout=EXECUTE_MODEL_TIMEOUT_S)
return output
def collective_rpc(self,
method: Union[str, Callable],
timeout: Optional[float] = None,
timeout: Optional[float] = 180.0,
args: tuple = (),
kwargs: Optional[dict] = None) -> list[Any]:
kwargs: Optional[dict] = None,
rank0_reply_only: bool = False) -> list[Any]:
start_time = time.monotonic()
kwargs = kwargs or {}
if self.is_failed:
raise RuntimeError("Executor failed.")
# NOTE: If the args are heterogeneous, then we pack them into a list,
# and unpack them in the method of every worker, because every worker
# knows their own rank.
@@ -109,30 +170,30 @@ class MultiprocExecutor(Executor):
else:
send_method = cloudpickle.dumps(
method, protocol=pickle.HIGHEST_PROTOCOL)
self.rpc_broadcast_mq.enqueue((send_method, args, kwargs))
self.rpc_broadcast_mq.enqueue(
(send_method, args, kwargs, rank0_reply_only))
responses = [None] * self.world_size
for w in self.workers:
workers = (self.workers[0], ) if rank0_reply_only else self.workers
responses = [None] * len(workers)
for w in workers:
dequeue_timeout = timeout - (time.monotonic() - start_time
) if timeout is not None else None
status, result = w.worker_response_mq.dequeue(
timeout=dequeue_timeout)
timeout=dequeue_timeout, cancel=self.shutdown_event)
if status != WorkerProc.ResponseStatus.SUCCESS:
raise RuntimeError(
"Worker failed with error %s, please check the"
" stack trace above for the root cause", result)
f"Worker failed with error '{result}', please check the"
" stack trace above for the root cause")
responses[w.rank] = result
return responses
except TimeoutError as e:
raise TimeoutError(f"RPC call to {method} timed out.") from e
except Exception as e:
# Re-raise any other exceptions
raise e
def _ensure_worker_termination(self):
@staticmethod
def _ensure_worker_termination(worker_procs: list[BaseProcess]):
"""Ensure that all worker processes are terminated. Assumes workers have
received termination requests. Waits for processing, then sends
termination and kill signals if needed."""
@@ -150,7 +211,7 @@ class MultiprocExecutor(Executor):
return False
# Send SIGTERM if still running
active_procs = [w.proc for w in self.workers if w.proc.is_alive()]
active_procs = [proc for proc in worker_procs if proc.is_alive()]
for p in active_procs:
p.terminate()
if not wait_for_termination(active_procs, 4):
@@ -159,22 +220,14 @@ class MultiprocExecutor(Executor):
for p in active_procs:
p.kill()
self._cleanup_sockets()
def _cleanup_sockets(self):
# (Removed by this commit.) Best-effort removal of the per-worker zmq
# IPC socket files left on disk after worker shutdown.
for w in self.workers:
# Remove the zmq ipc socket file
socket_path = w.ready_path.replace("ipc://", "")
# "if os" guards against the os module having been torn down to None
# during interpreter shutdown (finalizer context).
if os and os.path.exists(socket_path):
os.remove(socket_path)
def shutdown(self):
"""Properly shut down the executor and its workers"""
if not getattr(self, 'shutting_down', False):
self.shutting_down = True
self.shutdown_event.set()
for w in self.workers:
w.worker_response_mq = None
self._ensure_worker_termination()
self._ensure_worker_termination([w.proc for w in self.workers])
self.rpc_broadcast_mq = None
@@ -183,13 +236,30 @@ class MultiprocExecutor(Executor):
return
@dataclass
class UnreadyWorkerProcHandle:
"""WorkerProcess handle before READY."""
# Handle to the spawned worker process.
proc: BaseProcess
# Global rank of the worker within the executor's world.
rank: int
# Read end of the pipe on which the worker sends its READY message.
ready_pipe: Connection
@dataclass
class WorkerProcHandle:
proc: BaseProcess
rank: int
ready_path: str
worker_response_mq: MessageQueue # The worker process writes to this MQ
@classmethod
def from_unready_handle(
cls, unready_handle: UnreadyWorkerProcHandle,
worker_response_mq: MessageQueue) -> "WorkerProcHandle":
return cls(
proc=unready_handle.proc,
rank=unready_handle.rank,
worker_response_mq=worker_response_mq,
)
class WorkerProc:
"""Wrapper that runs one Worker in a separate process."""
@@ -203,7 +273,6 @@ class WorkerProc:
rank: int,
distributed_init_method: str,
input_shm_handle: Handle,
ready_path: str,
):
self.rank = rank
wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
@@ -231,18 +300,8 @@ class WorkerProc:
# Initializes a message queue for sending the model output
self.worker_response_mq = MessageQueue(1, 1)
worker_response_mq_handle = self.worker_response_mq.export_handle()
# Send Readiness signal to EngineCore process.
# Set linger here because we want to ensure the message has
# been sent before the context is closed.
with zmq_socket_ctx(ready_path, zmq.constants.PUSH,
linger=10000) as ready_socket:
payload = pickle.dumps(worker_response_mq_handle,
protocol=pickle.HIGHEST_PROTOCOL)
ready_socket.send_string(WorkerProc.READY_STR)
ready_socket.send(payload)
# Initialize device and loads weights
self.worker.init_device()
self.worker.load_model()
@@ -253,12 +312,10 @@ class WorkerProc:
rank: int,
distributed_init_method: str,
input_shm_handle, # Receive SchedulerOutput
) -> WorkerProcHandle:
) -> UnreadyWorkerProcHandle:
context = get_mp_context()
# ZMQ path for worker to send ready message and shm_broadcast handle
# back to core process.
ready_path = get_open_zmq_ipc_path()
# (reader, writer)
reader, writer = context.Pipe(duplex=False)
process_kwargs = {
"vllm_config": vllm_config,
@@ -266,24 +323,57 @@ class WorkerProc:
"rank": rank,
"distributed_init_method": distributed_init_method,
"input_shm_handle": input_shm_handle,
"ready_path": ready_path,
"ready_pipe": (reader, writer),
}
# Run EngineCore busy loop in background process.
proc = context.Process(target=WorkerProc.worker_main,
kwargs=process_kwargs,
name=f"VllmWorker-{rank}",
daemon=True)
with zmq_socket_ctx(ready_path, zmq.constants.PULL) as ready_socket:
proc.start()
proc.start()
writer.close()
return UnreadyWorkerProcHandle(proc, rank, reader)
# Wait for startup
worker_response_mq_handle = WorkerProc.wait_for_startup(
proc, ready_socket)
@staticmethod
def wait_for_ready(
unready_proc_handles: list[UnreadyWorkerProcHandle]
) -> list[WorkerProcHandle]:
worker_response_mq = MessageQueue.create_from_handle(
worker_response_mq_handle, 0)
e = Exception("WorkerProc initialization failed due to "
"an exception in a background process. "
"See stack trace for root cause.")
return WorkerProcHandle(proc, rank, ready_path, worker_response_mq)
pipes = {handle.ready_pipe: handle for handle in unready_proc_handles}
ready_proc_handles: list[Optional[WorkerProcHandle]] = (
[None] * len(unready_proc_handles))
while pipes:
ready = multiprocessing.connection.wait(pipes.keys())
for pipe in ready:
assert isinstance(pipe, Connection)
try:
# Wait until the WorkerProc is ready.
unready_proc_handle = pipes.pop(pipe)
response: dict[str, Any] = pipe.recv()
if response["status"] != "READY":
raise e
# Extract the message queue handle.
worker_response_mq = MessageQueue.create_from_handle(
response["handle"], 0)
ready_proc_handles[unready_proc_handle.rank] = (
WorkerProcHandle.from_unready_handle(
unready_proc_handle, worker_response_mq))
except EOFError:
e.__suppress_context__ = True
raise e from None
finally:
# Close connection.
pipe.close()
return cast(list[WorkerProcHandle], ready_proc_handles)
def shutdown(self):
self.rpc_broadcast_mq = None
@@ -312,51 +402,51 @@ class WorkerProc:
signal.signal(signal.SIGINT, signal_handler)
worker = None
# tuple[Connection, Connection]
reader, ready_writer = kwargs.pop("ready_pipe")
try:
reader.close()
worker = WorkerProc(*args, **kwargs)
# Send READY once we know everything is loaded
ready_writer.send({
"status":
WorkerProc.READY_STR,
"handle":
worker.worker_response_mq.export_handle(),
})
# Ensure message queues are ready. Will deadlock if re-ordered.
# Must be kept consistent with the Executor
worker.rpc_broadcast_mq.wait_until_ready()
worker.worker_response_mq.wait_until_ready()
ready_writer.close()
ready_writer = None
worker.worker_busy_loop()
except SystemExit:
logger.debug("Worker interrupted.")
except Exception:
# worker_busy_loop sends exceptions to Executor
# for shutdown, but if there is an error in startup or an
# error with IPC itself, we need to alert the parent.
psutil.Process().parent().send_signal(signal.SIGUSR1)
raise
# NOTE: if an Exception arises in busy_loop, we send
# a FAILURE message over the MQ RPC to notify the Executor,
# which triggers system shutdown.
# TODO(rob): handle case where the MQ itself breaks.
if ready_writer is not None:
logger.exception("WorkerProc failed to start.")
else:
logger.exception("WorkerProc failed.")
# The parent sends a SIGTERM to all worker processes if
# any worker dies. Set this value so we don't re-throw
# SystemExit() to avoid zmq exceptions in __del__.
shutdown_requested = True
finally:
if ready_writer is not None:
ready_writer.close()
# Clean up once worker exits busy loop
if worker is not None:
worker.shutdown()
worker = None
@staticmethod
def wait_for_startup(
proc: BaseProcess,
ready_socket: zmq.Socket,
) -> Optional[Handle]:
"""Wait until the Worker is ready."""
# (Removed by this commit, replaced by the pipe-based wait_for_ready.)
# Wait for Worker to send READY.
# Poll in POLLING_TIMEOUT_MS slices so a dead process is noticed
# instead of blocking forever on the socket.
while ready_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
logger.debug("Waiting for WorkerProc to startup.")
if not proc.is_alive():
raise RuntimeError("WorkerProc failed to start.")
message = ready_socket.recv_string()
assert message == WorkerProc.READY_STR
handle_frame = ready_socket.recv(copy=False)
# NOTE(review): pickle.loads on socket data — acceptable only because
# this is a local IPC socket owned by this process tree; never expose it.
handle = pickle.loads(handle_frame.buffer)
return handle
class ResponseStatus(Enum):
SUCCESS = auto()
@@ -365,7 +455,7 @@ class WorkerProc:
def worker_busy_loop(self):
"""Main busy loop for Multiprocessing Workers"""
while True:
method, args, kwargs = self.rpc_broadcast_mq.dequeue()
method, args, kwargs, rank0_only = self.rpc_broadcast_mq.dequeue()
try:
if isinstance(method, str):
@@ -377,12 +467,14 @@ class WorkerProc:
# Notes have been introduced in python 3.11
if hasattr(e, "add_note"):
e.add_note(traceback.format_exc())
logger.exception("WorkerProc hit an exception: %s", exc_info=e)
logger.exception("WorkerProc hit an exception.")
# exception might not be serializable, so we convert it to
# string, only for logging purpose.
self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.FAILURE, str(e)))
if not rank0_only or self.rank == 0:
self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.FAILURE, str(e)))
continue
self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.SUCCESS, output))
if not rank0_only or self.rank == 0:
self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.SUCCESS, output))