[P/D] Add a shutdown method to the Connector API (#22699)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import multiprocessing
|
||||
import os
|
||||
import pickle
|
||||
import queue
|
||||
import signal
|
||||
@@ -507,6 +506,7 @@ class WorkerProc:
|
||||
return cast(list[WorkerProcHandle], ready_proc_handles)
|
||||
|
||||
def shutdown(self):
|
||||
self.worker.shutdown()
|
||||
self.rpc_broadcast_mq = None
|
||||
self.worker_response_mq = None
|
||||
destroy_model_parallel()
|
||||
@@ -536,7 +536,7 @@ class WorkerProc:
|
||||
# tuple[Connection, Connection]
|
||||
reader, ready_writer = kwargs.pop("ready_pipe")
|
||||
death_pipe = kwargs.pop("death_pipe", None)
|
||||
|
||||
shutdown_event = threading.Event()
|
||||
# Start death monitoring thread if death_pipe is provided
|
||||
if death_pipe is not None:
|
||||
|
||||
@@ -548,7 +548,7 @@ class WorkerProc:
|
||||
# Parent process has exited, terminate this worker
|
||||
logger.info("Parent process exited, terminating worker")
|
||||
# Send signal to self to trigger clean shutdown
|
||||
os.kill(os.getpid(), signal.SIGTERM)
|
||||
shutdown_event.set()
|
||||
except Exception as e:
|
||||
logger.warning("Death monitoring error: %s", e)
|
||||
|
||||
@@ -576,7 +576,7 @@ class WorkerProc:
|
||||
ready_writer.close()
|
||||
ready_writer = None
|
||||
|
||||
worker.worker_busy_loop()
|
||||
worker.worker_busy_loop(cancel=shutdown_event)
|
||||
|
||||
except Exception:
|
||||
# NOTE: if an Exception arises in busy_loop, we send
|
||||
@@ -586,6 +586,8 @@ class WorkerProc:
|
||||
|
||||
if ready_writer is not None:
|
||||
logger.exception("WorkerProc failed to start.")
|
||||
elif shutdown_event.is_set():
|
||||
logger.info("WorkerProc shutting down.")
|
||||
else:
|
||||
logger.exception("WorkerProc failed.")
|
||||
|
||||
@@ -637,11 +639,11 @@ class WorkerProc:
|
||||
output = self.async_output_queue.get()
|
||||
self.enqueue_output(output)
|
||||
|
||||
def worker_busy_loop(self):
|
||||
def worker_busy_loop(self, cancel: Optional[threading.Event] = None):
|
||||
"""Main busy loop for Multiprocessing Workers"""
|
||||
while True:
|
||||
method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue()
|
||||
|
||||
method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue(
|
||||
cancel=cancel)
|
||||
try:
|
||||
if isinstance(method, str):
|
||||
func = getattr(self.worker, method)
|
||||
|
||||
Reference in New Issue
Block a user