[V1] [4/N] API Server: ZMQ/MP Utilities (#11541)

This commit is contained in:
Robert Shaw
2024-12-27 20:45:08 -05:00
committed by GitHub
parent a60731247f
commit df04dffade
12 changed files with 242 additions and 210 deletions

View File

@@ -1,11 +1,11 @@
import os
import weakref
from collections.abc import Sequence
from contextlib import contextmanager
from typing import (Any, Generic, Iterator, List, Optional, TypeVar, Union,
overload)
import zmq
from typing import (Any, Callable, Dict, Generic, List, Optional, TypeVar,
Union, overload)
from vllm.logger import init_logger
from vllm.utils import get_mp_context, kill_process_tree
logger = init_logger(__name__)
@@ -77,27 +77,58 @@ class ConstantList(Generic[T], Sequence):
return len(self._x)
@contextmanager
def make_zmq_socket(
path: str,
type: Any) -> Iterator[zmq.Socket]: # type: ignore[name-defined]
"""Context manager for a ZMQ socket"""
class BackgroundProcHandle:
"""
Utility class to handle creation, readiness, and shutdown
of background processes used by the AsyncLLM and LLMEngine.
"""
ctx = zmq.Context() # type: ignore[attr-defined]
try:
socket = ctx.socket(type)
def __init__(
self,
input_path: str,
output_path: str,
process_name: str,
target_fn: Callable,
process_kwargs: Dict[Any, Any],
):
self._finalizer = weakref.finalize(self, self.shutdown)
if type == zmq.constants.PULL:
socket.connect(path)
elif type == zmq.constants.PUSH:
socket.bind(path)
else:
raise ValueError(f"Unknown Socket Type: {type}")
context = get_mp_context()
reader, writer = context.Pipe(duplex=False)
yield socket
assert ("ready_pipe" not in process_kwargs
and "input_path" not in process_kwargs
and "output_path" not in process_kwargs)
process_kwargs["ready_pipe"] = writer
process_kwargs["input_path"] = input_path
process_kwargs["output_path"] = output_path
self.input_path = input_path
self.output_path = output_path
except KeyboardInterrupt:
logger.debug("Worker had Keyboard Interrupt.")
# Run the target function (e.g. a busy loop) in a background process.
self.proc = context.Process(target=target_fn, kwargs=process_kwargs)
self.proc.start()
finally:
ctx.destroy(linger=0)
# Wait for startup.
if reader.recv()["status"] != "READY":
raise RuntimeError(f"{process_name} initialization failed. "
"See root cause above.")
def __del__(self):
    """Run ``shutdown()`` when the handle object is being destroyed."""
    self.shutdown()
def shutdown(self):
    """Stop the background process and delete its ZMQ IPC socket files.

    The process, if one was started and is still alive, is asked to
    terminate and given up to 5 seconds to exit; if it is still alive
    after that, ``kill_process_tree`` is called on its pid.

    Afterwards the files behind ``output_path`` and ``input_path`` (with
    the ``ipc://`` scheme stripped) are removed if they exist.

    This method is safe to call repeatedly and on a handle whose
    ``__init__`` did not run to completion: attributes that were never
    assigned are skipped instead of raising ``AttributeError``, and a
    socket file that disappears between the existence check and the
    removal is ignored.
    """
    # `proc` is assigned late in __init__, so it may be missing when this
    # runs as cleanup for a handle that failed during construction.
    proc = getattr(self, "proc", None)
    if proc is not None and proc.is_alive():
        proc.terminate()
        proc.join(5)

        # Escalate only if the polite terminate did not work in time.
        if proc.is_alive():
            kill_process_tree(proc.pid)

    # NOTE(review): the original guarded on `os` being truthy, presumably
    # because this can run from __del__ while module globals are being
    # torn down at interpreter exit. The guard is kept as-is.
    if not os:
        return

    # Remove zmq ipc socket files (output first, then input, as before).
    for attr_name in ("output_path", "input_path"):
        ipc_socket = getattr(self, attr_name, None)
        if ipc_socket is None:
            # Path was never recorded; nothing to clean up for it.
            continue

        socket_file = ipc_socket.replace("ipc://", "")
        if os.path.exists(socket_file):
            try:
                os.remove(socket_file)
            except FileNotFoundError:
                # Another cleanup removed it between the check and the
                # remove call; the goal (file gone) is already met.
                pass