[V1][Frontend] Improve Shutdown And Logs (#11737)
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com> Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com> Signed-off-by: Nick Hill <nhill@redhat.com> Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Andrew Feldman <afeldman@neuralmagic.com> Co-authored-by: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -1,21 +1,23 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import multiprocessing
|
||||
import os
|
||||
import pickle
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import weakref
|
||||
from concurrent.futures import Future
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from functools import partial
|
||||
from multiprocessing.connection import Connection
|
||||
from multiprocessing.process import BaseProcess
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from threading import Thread
|
||||
from typing import Any, Callable, Optional, Union, cast
|
||||
|
||||
import cloudpickle
|
||||
import psutil
|
||||
import zmq
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (destroy_distributed_environment,
|
||||
@@ -26,8 +28,9 @@ from vllm.executor.multiproc_worker_utils import (
|
||||
_add_prefix, set_multiprocessing_worker_envs)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import (get_distributed_init_method, get_mp_context,
|
||||
get_open_port, get_open_zmq_ipc_path, zmq_socket_ctx)
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
get_open_port)
|
||||
from vllm.v1.executor.abstract import Executor, FailureCallback
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -35,6 +38,8 @@ logger = init_logger(__name__)
|
||||
POLLING_TIMEOUT_MS = 5000
|
||||
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
|
||||
|
||||
EXECUTE_MODEL_TIMEOUT_S = 30
|
||||
|
||||
|
||||
class MultiprocExecutor(Executor):
|
||||
|
||||
@@ -42,19 +47,9 @@ class MultiprocExecutor(Executor):
|
||||
# Call self.shutdown at exit to clean up
|
||||
# and ensure workers will be terminated.
|
||||
self._finalizer = weakref.finalize(self, self.shutdown)
|
||||
|
||||
# The child processes will send SIGUSR1 when unrecoverable
|
||||
# errors happen.
|
||||
def sigusr1_handler(signum, frame):
|
||||
logger.fatal(
|
||||
"MulitprocExecutor got fatal signal from worker processes, "
|
||||
"shutting down. See stack trace above for root cause issue.")
|
||||
# Propagate error up to parent process.
|
||||
parent_process = psutil.Process().parent()
|
||||
parent_process.send_signal(signal.SIGUSR1)
|
||||
self.shutdown()
|
||||
|
||||
signal.signal(signal.SIGUSR1, sigusr1_handler)
|
||||
self.is_failed = False
|
||||
self.shutdown_event = threading.Event()
|
||||
self.failure_callback: Optional[FailureCallback] = None
|
||||
|
||||
self.world_size = self.parallel_config.world_size
|
||||
tensor_parallel_size = self.parallel_config.tensor_parallel_size
|
||||
@@ -78,28 +73,94 @@ class MultiprocExecutor(Executor):
|
||||
scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
|
||||
|
||||
# Create workers
|
||||
self.workers: list[WorkerProcHandle] = []
|
||||
for rank in range(self.world_size):
|
||||
worker = WorkerProc.make_worker_process(self.vllm_config, rank,
|
||||
rank,
|
||||
distributed_init_method,
|
||||
scheduler_output_handle)
|
||||
self.workers.append(worker)
|
||||
unready_workers: list[UnreadyWorkerProcHandle] = []
|
||||
success = False
|
||||
try:
|
||||
for rank in range(self.world_size):
|
||||
unready_workers.append(
|
||||
WorkerProc.make_worker_process(
|
||||
vllm_config=self.vllm_config,
|
||||
local_rank=rank,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
input_shm_handle=scheduler_output_handle,
|
||||
))
|
||||
|
||||
# Ensure message queues are ready. Will deadlock if re-ordered
|
||||
# Must be kept consistent with the WorkerProc
|
||||
self.rpc_broadcast_mq.wait_until_ready()
|
||||
for w in self.workers:
|
||||
w.worker_response_mq.wait_until_ready()
|
||||
# Workers must be created before wait_for_ready to avoid
|
||||
# deadlock, since worker.init_device() does a device sync.
|
||||
self.workers = WorkerProc.wait_for_ready(unready_workers)
|
||||
|
||||
# Ensure message queues are ready. Will deadlock if re-ordered
|
||||
# Must be kept consistent with the WorkerProc.
|
||||
self.rpc_broadcast_mq.wait_until_ready()
|
||||
for w in self.workers:
|
||||
w.worker_response_mq.wait_until_ready()
|
||||
|
||||
self.start_worker_monitor()
|
||||
success = True
|
||||
finally:
|
||||
if not success:
|
||||
# Clean up the worker procs if there was a failure.
|
||||
self._ensure_worker_termination(
|
||||
[w.proc for w in unready_workers])
|
||||
|
||||
def start_worker_monitor(self):
    """Launch a daemon thread that watches the worker processes for exits.

    The thread blocks until the sentinel of any process in
    ``self.workers`` is signalled. If the executor still exists and is
    not already shutting down, the thread marks it as failed, logs the
    name of the process that died, shuts the executor down and finally
    invokes the registered failure callback (clearing it first so that it
    runs at most once).
    """
    handles = self.workers
    # Hold the executor only weakly so this thread does not keep it alive.
    executor_ref = weakref.ref(self)

    def monitor_workers():
        # Block until at least one worker process sentinel is signalled.
        finished = multiprocessing.connection.wait(
            [handle.proc.sentinel for handle in handles])

        executor = executor_ref()
        if not executor:
            # The executor no longer exists; there is nobody to notify.
            return
        if getattr(executor, 'shutting_down', False):
            # A shutdown is already in progress; nothing to report.
            return

        executor.is_failed = True
        first_dead = finished[0]
        proc_name = next(handle.proc.name for handle in handles
                         if handle.proc.sentinel == first_dead)
        logger.error(
            "Worker proc %s died unexpectedly, "
            "shutting down executor.", proc_name)
        executor.shutdown()

        # Clear the callback before invoking it so it fires only once.
        on_failure = executor.failure_callback
        if on_failure is None:
            return
        executor.failure_callback = None
        on_failure()

    monitor = Thread(target=monitor_workers,
                     daemon=True,
                     name="MultiprocWorkerMonitor")
    monitor.start()
|
||||
|
||||
def register_failure_callback(self, callback: FailureCallback):
    """Register ``callback`` to be run when the executor fails.

    If the executor is already marked as failed the callback is invoked
    immediately and is not stored; otherwise it is stored in
    ``self.failure_callback``, replacing any previous one.
    """
    if not self.is_failed:
        self.failure_callback = callback
        return
    # Failure already happened: notify the caller right away.
    callback()
|
||||
|
||||
def execute_model(
    self,
    scheduler_output,
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
    """Run ``execute_model`` on the workers and return the single reply.

    The call is broadcast through ``collective_rpc`` with
    ``rank0_reply_only=True`` and a timeout of EXECUTE_MODEL_TIMEOUT_S
    seconds. Exactly one response is expected; unpacking raises
    ``ValueError`` otherwise.
    """
    replies = self.collective_rpc("execute_model",
                                  args=(scheduler_output, ),
                                  rank0_reply_only=True,
                                  timeout=EXECUTE_MODEL_TIMEOUT_S)
    # Only one reply is requested, so the result must hold one item.
    (rank0_output, ) = replies
    return rank0_output
|
||||
|
||||
def collective_rpc(self,
|
||||
method: Union[str, Callable],
|
||||
timeout: Optional[float] = None,
|
||||
timeout: Optional[float] = 180.0,
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict] = None) -> list[Any]:
|
||||
kwargs: Optional[dict] = None,
|
||||
rank0_reply_only: bool = False) -> list[Any]:
|
||||
start_time = time.monotonic()
|
||||
kwargs = kwargs or {}
|
||||
|
||||
if self.is_failed:
|
||||
raise RuntimeError("Executor failed.")
|
||||
|
||||
# NOTE: If the args are heterogeneous, then we pack them into a list,
|
||||
# and unpack them in the method of every worker, because every worker
|
||||
# knows their own rank.
|
||||
@@ -109,30 +170,30 @@ class MultiprocExecutor(Executor):
|
||||
else:
|
||||
send_method = cloudpickle.dumps(
|
||||
method, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
self.rpc_broadcast_mq.enqueue((send_method, args, kwargs))
|
||||
self.rpc_broadcast_mq.enqueue(
|
||||
(send_method, args, kwargs, rank0_reply_only))
|
||||
|
||||
responses = [None] * self.world_size
|
||||
for w in self.workers:
|
||||
workers = (self.workers[0], ) if rank0_reply_only else self.workers
|
||||
responses = [None] * len(workers)
|
||||
for w in workers:
|
||||
dequeue_timeout = timeout - (time.monotonic() - start_time
|
||||
) if timeout is not None else None
|
||||
status, result = w.worker_response_mq.dequeue(
|
||||
timeout=dequeue_timeout)
|
||||
timeout=dequeue_timeout, cancel=self.shutdown_event)
|
||||
|
||||
if status != WorkerProc.ResponseStatus.SUCCESS:
|
||||
raise RuntimeError(
|
||||
"Worker failed with error %s, please check the"
|
||||
" stack trace above for the root cause", result)
|
||||
f"Worker failed with error '{result}', please check the"
|
||||
" stack trace above for the root cause")
|
||||
|
||||
responses[w.rank] = result
|
||||
|
||||
return responses
|
||||
except TimeoutError as e:
|
||||
raise TimeoutError(f"RPC call to {method} timed out.") from e
|
||||
except Exception as e:
|
||||
# Re-raise any other exceptions
|
||||
raise e
|
||||
|
||||
def _ensure_worker_termination(self):
|
||||
@staticmethod
|
||||
def _ensure_worker_termination(worker_procs: list[BaseProcess]):
|
||||
"""Ensure that all worker processes are terminated. Assumes workers have
|
||||
received termination requests. Waits for processing, then sends
|
||||
termination and kill signals if needed."""
|
||||
@@ -150,7 +211,7 @@ class MultiprocExecutor(Executor):
|
||||
return False
|
||||
|
||||
# Send SIGTERM if still running
|
||||
active_procs = [w.proc for w in self.workers if w.proc.is_alive()]
|
||||
active_procs = [proc for proc in worker_procs if proc.is_alive()]
|
||||
for p in active_procs:
|
||||
p.terminate()
|
||||
if not wait_for_termination(active_procs, 4):
|
||||
@@ -159,22 +220,14 @@ class MultiprocExecutor(Executor):
|
||||
for p in active_procs:
|
||||
p.kill()
|
||||
|
||||
self._cleanup_sockets()
|
||||
|
||||
def _cleanup_sockets(self):
|
||||
for w in self.workers:
|
||||
# Remove the zmq ipc socket file
|
||||
socket_path = w.ready_path.replace("ipc://", "")
|
||||
if os and os.path.exists(socket_path):
|
||||
os.remove(socket_path)
|
||||
|
||||
def shutdown(self):
|
||||
"""Properly shut down the executor and its workers"""
|
||||
if not getattr(self, 'shutting_down', False):
|
||||
self.shutting_down = True
|
||||
self.shutdown_event.set()
|
||||
for w in self.workers:
|
||||
w.worker_response_mq = None
|
||||
self._ensure_worker_termination()
|
||||
self._ensure_worker_termination([w.proc for w in self.workers])
|
||||
|
||||
self.rpc_broadcast_mq = None
|
||||
|
||||
@@ -183,13 +236,30 @@ class MultiprocExecutor(Executor):
|
||||
return
|
||||
|
||||
|
||||
@dataclass
class UnreadyWorkerProcHandle:
    """WorkerProcess handle before READY."""

    # The worker process this handle refers to.
    proc: BaseProcess
    # Rank assigned to the worker.
    rank: int
    # Connection on which the worker's readiness message is received.
    ready_pipe: Connection
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkerProcHandle:
|
||||
proc: BaseProcess
|
||||
rank: int
|
||||
ready_path: str
|
||||
worker_response_mq: MessageQueue # The worker process writes to this MQ
|
||||
|
||||
@classmethod
|
||||
def from_unready_handle(
|
||||
cls, unready_handle: UnreadyWorkerProcHandle,
|
||||
worker_response_mq: MessageQueue) -> "WorkerProcHandle":
|
||||
return cls(
|
||||
proc=unready_handle.proc,
|
||||
rank=unready_handle.rank,
|
||||
worker_response_mq=worker_response_mq,
|
||||
)
|
||||
|
||||
|
||||
class WorkerProc:
|
||||
"""Wrapper that runs one Worker in a separate process."""
|
||||
@@ -203,7 +273,6 @@ class WorkerProc:
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
input_shm_handle: Handle,
|
||||
ready_path: str,
|
||||
):
|
||||
self.rank = rank
|
||||
wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
|
||||
@@ -231,18 +300,8 @@ class WorkerProc:
|
||||
|
||||
# Initializes a message queue for sending the model output
|
||||
self.worker_response_mq = MessageQueue(1, 1)
|
||||
worker_response_mq_handle = self.worker_response_mq.export_handle()
|
||||
|
||||
# Send Readiness signal to EngineCore process.
|
||||
# Set linger here because we want to ensure the message has
|
||||
# been sent before the context is closed.
|
||||
with zmq_socket_ctx(ready_path, zmq.constants.PUSH,
|
||||
linger=10000) as ready_socket:
|
||||
payload = pickle.dumps(worker_response_mq_handle,
|
||||
protocol=pickle.HIGHEST_PROTOCOL)
|
||||
ready_socket.send_string(WorkerProc.READY_STR)
|
||||
ready_socket.send(payload)
|
||||
|
||||
# Initialize device and loads weights
|
||||
self.worker.init_device()
|
||||
self.worker.load_model()
|
||||
|
||||
@@ -253,12 +312,10 @@ class WorkerProc:
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
input_shm_handle, # Receive SchedulerOutput
|
||||
) -> WorkerProcHandle:
|
||||
) -> UnreadyWorkerProcHandle:
|
||||
context = get_mp_context()
|
||||
|
||||
# ZMQ path for worker to send ready message and shm_broadcast handle
|
||||
# back to core process.
|
||||
ready_path = get_open_zmq_ipc_path()
|
||||
# (reader, writer)
|
||||
reader, writer = context.Pipe(duplex=False)
|
||||
|
||||
process_kwargs = {
|
||||
"vllm_config": vllm_config,
|
||||
@@ -266,24 +323,57 @@ class WorkerProc:
|
||||
"rank": rank,
|
||||
"distributed_init_method": distributed_init_method,
|
||||
"input_shm_handle": input_shm_handle,
|
||||
"ready_path": ready_path,
|
||||
"ready_pipe": (reader, writer),
|
||||
}
|
||||
# Run the WorkerProc busy loop in a background process.
|
||||
proc = context.Process(target=WorkerProc.worker_main,
|
||||
kwargs=process_kwargs,
|
||||
name=f"VllmWorker-{rank}",
|
||||
daemon=True)
|
||||
|
||||
with zmq_socket_ctx(ready_path, zmq.constants.PULL) as ready_socket:
|
||||
proc.start()
|
||||
proc.start()
|
||||
writer.close()
|
||||
return UnreadyWorkerProcHandle(proc, rank, reader)
|
||||
|
||||
# Wait for startup
|
||||
worker_response_mq_handle = WorkerProc.wait_for_startup(
|
||||
proc, ready_socket)
|
||||
@staticmethod
|
||||
def wait_for_ready(
|
||||
unready_proc_handles: list[UnreadyWorkerProcHandle]
|
||||
) -> list[WorkerProcHandle]:
|
||||
|
||||
worker_response_mq = MessageQueue.create_from_handle(
|
||||
worker_response_mq_handle, 0)
|
||||
e = Exception("WorkerProc initialization failed due to "
|
||||
"an exception in a background process. "
|
||||
"See stack trace for root cause.")
|
||||
|
||||
return WorkerProcHandle(proc, rank, ready_path, worker_response_mq)
|
||||
pipes = {handle.ready_pipe: handle for handle in unready_proc_handles}
|
||||
ready_proc_handles: list[Optional[WorkerProcHandle]] = (
|
||||
[None] * len(unready_proc_handles))
|
||||
while pipes:
|
||||
ready = multiprocessing.connection.wait(pipes.keys())
|
||||
for pipe in ready:
|
||||
assert isinstance(pipe, Connection)
|
||||
try:
|
||||
# Wait until the WorkerProc is ready.
|
||||
unready_proc_handle = pipes.pop(pipe)
|
||||
response: dict[str, Any] = pipe.recv()
|
||||
if response["status"] != "READY":
|
||||
raise e
|
||||
|
||||
# Extract the message queue handle.
|
||||
worker_response_mq = MessageQueue.create_from_handle(
|
||||
response["handle"], 0)
|
||||
ready_proc_handles[unready_proc_handle.rank] = (
|
||||
WorkerProcHandle.from_unready_handle(
|
||||
unready_proc_handle, worker_response_mq))
|
||||
|
||||
except EOFError:
|
||||
e.__suppress_context__ = True
|
||||
raise e from None
|
||||
|
||||
finally:
|
||||
# Close connection.
|
||||
pipe.close()
|
||||
|
||||
return cast(list[WorkerProcHandle], ready_proc_handles)
|
||||
|
||||
def shutdown(self):
|
||||
self.rpc_broadcast_mq = None
|
||||
@@ -312,51 +402,51 @@ class WorkerProc:
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
worker = None
|
||||
# tuple[Connection, Connection]
|
||||
reader, ready_writer = kwargs.pop("ready_pipe")
|
||||
try:
|
||||
reader.close()
|
||||
worker = WorkerProc(*args, **kwargs)
|
||||
|
||||
# Send READY once we know everything is loaded
|
||||
ready_writer.send({
|
||||
"status":
|
||||
WorkerProc.READY_STR,
|
||||
"handle":
|
||||
worker.worker_response_mq.export_handle(),
|
||||
})
|
||||
|
||||
# Ensure message queues are ready. Will deadlock if re-ordered.
|
||||
# Must be kept consistent with the Executor
|
||||
worker.rpc_broadcast_mq.wait_until_ready()
|
||||
worker.worker_response_mq.wait_until_ready()
|
||||
ready_writer.close()
|
||||
ready_writer = None
|
||||
|
||||
worker.worker_busy_loop()
|
||||
|
||||
except SystemExit:
|
||||
logger.debug("Worker interrupted.")
|
||||
|
||||
except Exception:
|
||||
# worker_busy_loop sends exceptions to Executor
|
||||
# for shutdown, but if there is an error in startup or an
|
||||
# error with IPC itself, we need to alert the parent.
|
||||
psutil.Process().parent().send_signal(signal.SIGUSR1)
|
||||
raise
|
||||
# NOTE: if an Exception arises in busy_loop, we send
|
||||
# a FAILURE message over the MQ RPC to notify the Executor,
|
||||
# which triggers system shutdown.
|
||||
# TODO(rob): handle case where the MQ itself breaks.
|
||||
|
||||
if ready_writer is not None:
|
||||
logger.exception("WorkerProc failed to start.")
|
||||
else:
|
||||
logger.exception("WorkerProc failed.")
|
||||
|
||||
# The parent sends a SIGTERM to all worker processes if
|
||||
# any worker dies. Set this value so we don't re-throw
|
||||
# SystemExit() to avoid zmq exceptions in __del__.
|
||||
shutdown_requested = True
|
||||
|
||||
finally:
|
||||
if ready_writer is not None:
|
||||
ready_writer.close()
|
||||
# Clean up once worker exits busy loop
|
||||
if worker is not None:
|
||||
worker.shutdown()
|
||||
worker = None
|
||||
|
||||
@staticmethod
def wait_for_startup(
    proc: BaseProcess,
    ready_socket: zmq.Socket,
) -> Optional[Handle]:
    """Block until the worker announces readiness on ``ready_socket``.

    The socket is polled in intervals of POLLING_TIMEOUT_MS. After each
    poll that yields nothing, a ``RuntimeError`` is raised if ``proc`` is
    no longer alive. Once data is available, the first message must equal
    ``WorkerProc.READY_STR``; the following frame is unpickled and
    returned as the handle.
    """
    while True:
        if ready_socket.poll(timeout=POLLING_TIMEOUT_MS) != 0:
            # Something arrived on the socket; stop waiting.
            break
        logger.debug("Waiting for WorkerProc to startup.")

        if not proc.is_alive():
            raise RuntimeError("WorkerProc failed to start.")

    ready_message = ready_socket.recv_string()
    assert ready_message == WorkerProc.READY_STR
    # The handle arrives as a second, pickled frame; read it without
    # copying and deserialize straight from the frame buffer.
    return pickle.loads(ready_socket.recv(copy=False).buffer)
|
||||
|
||||
class ResponseStatus(Enum):
|
||||
SUCCESS = auto()
|
||||
@@ -365,7 +455,7 @@ class WorkerProc:
|
||||
def worker_busy_loop(self):
|
||||
"""Main busy loop for Multiprocessing Workers"""
|
||||
while True:
|
||||
method, args, kwargs = self.rpc_broadcast_mq.dequeue()
|
||||
method, args, kwargs, rank0_only = self.rpc_broadcast_mq.dequeue()
|
||||
|
||||
try:
|
||||
if isinstance(method, str):
|
||||
@@ -377,12 +467,14 @@ class WorkerProc:
|
||||
# Notes have been introduced in python 3.11
|
||||
if hasattr(e, "add_note"):
|
||||
e.add_note(traceback.format_exc())
|
||||
logger.exception("WorkerProc hit an exception: %s", exc_info=e)
|
||||
logger.exception("WorkerProc hit an exception.")
|
||||
# exception might not be serializable, so we convert it to
|
||||
# string, only for logging purpose.
|
||||
self.worker_response_mq.enqueue(
|
||||
(WorkerProc.ResponseStatus.FAILURE, str(e)))
|
||||
if not rank0_only or self.rank == 0:
|
||||
self.worker_response_mq.enqueue(
|
||||
(WorkerProc.ResponseStatus.FAILURE, str(e)))
|
||||
continue
|
||||
|
||||
self.worker_response_mq.enqueue(
|
||||
(WorkerProc.ResponseStatus.SUCCESS, output))
|
||||
if not rank0_only or self.rank == 0:
|
||||
self.worker_response_mq.enqueue(
|
||||
(WorkerProc.ResponseStatus.SUCCESS, output))
|
||||
|
||||
Reference in New Issue
Block a user