[P/D] Add a shutdown method to the Connector API (#22699)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
Chauncey
2025-09-08 14:07:00 +08:00
committed by GitHub
parent 8c892b1831
commit 61aa4b2901
10 changed files with 52 additions and 12 deletions

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import multiprocessing
import os
import pickle
import queue
import signal
@@ -507,6 +506,7 @@ class WorkerProc:
return cast(list[WorkerProcHandle], ready_proc_handles)
def shutdown(self):
self.worker.shutdown()
self.rpc_broadcast_mq = None
self.worker_response_mq = None
destroy_model_parallel()
@@ -536,7 +536,7 @@ class WorkerProc:
# tuple[Connection, Connection]
reader, ready_writer = kwargs.pop("ready_pipe")
death_pipe = kwargs.pop("death_pipe", None)
shutdown_event = threading.Event()
# Start death monitoring thread if death_pipe is provided
if death_pipe is not None:
@@ -548,7 +548,7 @@ class WorkerProc:
# Parent process has exited, terminate this worker
logger.info("Parent process exited, terminating worker")
# Send signal to self to trigger clean shutdown
os.kill(os.getpid(), signal.SIGTERM)
shutdown_event.set()
except Exception as e:
logger.warning("Death monitoring error: %s", e)
@@ -576,7 +576,7 @@ class WorkerProc:
ready_writer.close()
ready_writer = None
worker.worker_busy_loop()
worker.worker_busy_loop(cancel=shutdown_event)
except Exception:
# NOTE: if an Exception arises in busy_loop, we send
@@ -586,6 +586,8 @@ class WorkerProc:
if ready_writer is not None:
logger.exception("WorkerProc failed to start.")
elif shutdown_event.is_set():
logger.info("WorkerProc shutting down.")
else:
logger.exception("WorkerProc failed.")
@@ -637,11 +639,11 @@ class WorkerProc:
output = self.async_output_queue.get()
self.enqueue_output(output)
def worker_busy_loop(self):
def worker_busy_loop(self, cancel: Optional[threading.Event] = None):
"""Main busy loop for Multiprocessing Workers"""
while True:
method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue()
method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue(
cancel=cancel)
try:
if isinstance(method, str):
func = getattr(self.worker, method)