[V1] [4/N] API Server: ZMQ/MP Utilities (#11541)
This commit is contained in:
@@ -1,11 +1,11 @@
|
||||
import os
|
||||
import weakref
|
||||
from collections.abc import Sequence
|
||||
from contextlib import contextmanager
|
||||
from typing import (Any, Generic, Iterator, List, Optional, TypeVar, Union,
|
||||
overload)
|
||||
|
||||
import zmq
|
||||
from typing import (Any, Callable, Dict, Generic, List, Optional, TypeVar,
|
||||
Union, overload)
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import get_mp_context, kill_process_tree
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -77,27 +77,58 @@ class ConstantList(Generic[T], Sequence):
|
||||
return len(self._x)
|
||||
|
||||
|
||||
@contextmanager
def make_zmq_socket(
        path: str,
        type: Any) -> Iterator[zmq.Socket]:  # type: ignore[name-defined]
    """Context manager for a ZMQ socket.

    Creates a fresh ZMQ context and a socket of the requested ``type``.
    PULL sockets connect to ``path`` and PUSH sockets bind to it; any other
    socket type raises ``ValueError``. A ``KeyboardInterrupt`` raised while
    the socket is in use is logged at debug level and not re-raised. The
    context is always destroyed with ``linger=0`` on exit.
    """

    ctx = zmq.Context()  # type: ignore[attr-defined]
    try:
        socket = ctx.socket(type)

        if type == zmq.constants.PULL:
            socket.connect(path)
        elif type == zmq.constants.PUSH:
            socket.bind(path)
        else:
            raise ValueError(f"Unknown Socket Type: {type}")

        yield socket

    except KeyboardInterrupt:
        logger.debug("Worker had Keyboard Interrupt.")

    finally:
        ctx.destroy(linger=0)


class BackgroundProcHandle:
    """
    Utility class to handle creation, readiness, and shutdown
    of background processes used by the AsyncLLM and LLMEngine.
    """

    def __init__(
        self,
        input_path: str,
        output_path: str,
        process_name: str,
        target_fn: Callable,
        process_kwargs: Dict[Any, Any],
    ):
        """Start ``target_fn`` in a background process and wait for it.

        Args:
            input_path: Passed to the child as ``input_path``; also removed
                from disk on shutdown (after stripping ``ipc://``).
            output_path: Passed to the child as ``output_path``; also removed
                from disk on shutdown (after stripping ``ipc://``).
            process_name: Used only in the startup-failure error message.
            target_fn: Callable run in the child process.
            process_kwargs: Extra keyword arguments for ``target_fn``. Must
                not contain ``ready_pipe``, ``input_path`` or ``output_path``.
                The dict is not modified.

        Raises:
            RuntimeError: If the child reports a status other than
                ``"READY"`` or exits before reporting any status.
        """
        context = get_mp_context()
        reader, writer = context.Pipe(duplex=False)

        assert ("ready_pipe" not in process_kwargs
                and "input_path" not in process_kwargs
                and "output_path" not in process_kwargs)
        # Build a private copy so the caller's dict is left untouched and can
        # be reused for another handle.
        child_kwargs = {
            **process_kwargs,
            "ready_pipe": writer,
            "input_path": input_path,
            "output_path": output_path,
        }
        self.input_path = input_path
        self.output_path = output_path

        # Run the target function in a background process.
        self.proc = context.Process(target=target_fn, kwargs=child_kwargs)
        # The finalizer callback must not reference `self` (a bound method
        # would keep this object alive forever), so hand it only the
        # resources it needs to release.
        self._finalizer = weakref.finalize(self,
                                           BackgroundProcHandle._cleanup,
                                           self.proc, input_path, output_path)
        self.proc.start()
        # Drop the parent's copy of the write end: once the child is the only
        # writer, recv() raises EOFError instead of blocking forever if the
        # child dies before reporting its status.
        writer.close()

        failure_msg = (f"{process_name} initialization failed. "
                       "See root cause above.")

        # Wait for startup.
        try:
            status = reader.recv()["status"]
        except EOFError as err:
            raise RuntimeError(failure_msg) from err
        finally:
            reader.close()
        if status != "READY":
            raise RuntimeError(failure_msg)

    def __del__(self):
        # Kept for backward compatibility; the weakref finalizer already
        # covers garbage collection and interpreter exit.
        self.shutdown()

    def shutdown(self):
        """Terminate the background process and remove its IPC socket files.

        Safe to call more than once and on an instance whose ``__init__``
        did not get as far as creating the process; in both cases the extra
        call does nothing.
        """
        finalizer = getattr(self, "_finalizer", None)
        if finalizer is not None:
            # A finalize object runs its callback at most once.
            finalizer()

    @staticmethod
    def _cleanup(proc, input_path: str, output_path: str) -> None:
        """Stop ``proc`` and delete the socket files behind both paths."""
        # Shutdown the process if needed.
        if proc.is_alive():
            proc.terminate()
            proc.join(5)

            if proc.is_alive():
                kill_process_tree(proc.pid)

        # Remove zmq ipc socket files
        if not os:
            # Module globals may already be torn down at interpreter exit.
            return
        for ipc_socket in (output_path, input_path):
            socket_file = ipc_socket.replace("ipc://", "")
            if os.path.exists(socket_file):
                try:
                    os.remove(socket_file)
                except FileNotFoundError:
                    # Removed by someone else between the check and the
                    # unlink; the goal (file gone) is already met.
                    pass
|
||||
|
||||
Reference in New Issue
Block a user