[V1] [4/N] API Server: ZMQ/MP Utilities (#11541)

This commit is contained in:
Robert Shaw
2024-12-27 20:45:08 -05:00
committed by GitHub
parent a60731247f
commit df04dffade
12 changed files with 242 additions and 210 deletions

View File

@@ -3,20 +3,19 @@ import queue
import signal
import threading
import time
from dataclasses import dataclass
from multiprocessing.process import BaseProcess
from multiprocessing.connection import Connection
from typing import List, Tuple, Type
import psutil
import zmq
import zmq.asyncio
from msgspec import msgpack
from vllm.config import CacheConfig, VllmConfig
from vllm.executor.multiproc_worker_utils import get_mp_context
from vllm.logger import init_logger
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import get_exception_traceback, zmq_socket_ctx
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
EngineCoreProfile, EngineCoreRequest,
@@ -25,14 +24,13 @@ from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
from vllm.v1.executor.abstract import Executor
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import PickleEncoder
from vllm.v1.utils import make_zmq_socket
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
# Timeout, in milliseconds, passed to each socket.poll() call while waiting
# for EngineCoreProc to start up.
POLLING_TIMEOUT_MS = 5000
# The same timeout expressed in whole seconds.
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
# Minimum number of seconds between two stat logs in _log_stats.
# NOTE(review): LOGGING_TIME_S is assigned twice; the later literal (5) is the
# effective value. These look like the old and new lines of a diff shown
# together -- confirm which one should remain.
LOGGING_TIME_S = POLLING_TIMEOUT_S
LOGGING_TIME_S = 5
class EngineCore:
@@ -42,9 +40,10 @@ class EngineCore:
self,
vllm_config: VllmConfig,
executor_class: Type[Executor],
usage_context: UsageContext,
log_stats: bool = False,
):
assert vllm_config.model_config.runner_type != "pooling"
self.log_stats = log_stats
logger.info("Initializing an LLM engine (v%s) with config: %s",
VLLM_VERSION, vllm_config)
@@ -134,29 +133,19 @@ class EngineCore:
self.model_executor.profile(is_start)
# Record describing a spawned EngineCore process: the process object plus the
# three path strings passed to make_engine_core_process.
@dataclass
class EngineCoreProcHandle:
# The started background process.
proc: BaseProcess
# Path on which the readiness message is exchanged during startup.
ready_path: str
# Path of the socket the engine core pulls requests from.
input_path: str
# Path of the socket the engine core pushes outputs to.
output_path: str
class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process."""
READY_STR = "READY"
def __init__(
self,
vllm_config: VllmConfig,
executor_class: Type[Executor],
usage_context: UsageContext,
input_path: str,
output_path: str,
ready_path: str,
ready_pipe: Connection,
vllm_config: VllmConfig,
executor_class: Type[Executor],
log_stats: bool = False,
):
super().__init__(vllm_config, executor_class, usage_context)
super().__init__(vllm_config, executor_class, log_stats)
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
@@ -173,68 +162,7 @@ class EngineCoreProc(EngineCore):
daemon=True).start()
# Send Readiness signal to EngineClient.
with make_zmq_socket(ready_path, zmq.constants.PUSH) as ready_socket:
ready_socket.send_string(EngineCoreProc.READY_STR)
@staticmethod
def wait_for_startup(
    proc: BaseProcess,
    ready_path: str,
) -> None:
    """Block until the EngineCore process signals that it is ready.

    Connects a PULL socket to ``ready_path`` and polls it in
    ``POLLING_TIMEOUT_MS`` intervals until a message arrives, checking
    after every timed-out poll that ``proc`` is still alive.

    Args:
        proc: The process expected to send the readiness message.
        ready_path: ZMQ address the PULL socket connects to.

    Raises:
        RuntimeError: If ``proc`` is no longer alive while waiting.
        AssertionError: If the received message is not
            ``EngineCoreProc.READY_STR``.
    """
    # Create the context before entering ``try`` so that a failure here
    # propagates as-is instead of being masked by an unbound ``sync_ctx``
    # in the ``finally`` clause.
    sync_ctx = zmq.Context()  # type: ignore[attr-defined]
    try:
        socket = sync_ctx.socket(zmq.constants.PULL)
        socket.connect(ready_path)

        # Wait for EngineCore to send EngineCoreProc.READY_STR.
        while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
            logger.debug("Waiting for EngineCoreProc to startup.")

            if not proc.is_alive():
                raise RuntimeError("EngineCoreProc failed to start.")

        message = socket.recv_string()
        assert message == EngineCoreProc.READY_STR

    except BaseException as e:
        logger.exception(e)
        # Bare ``raise`` re-raises the active exception unchanged.
        raise

    finally:
        # Destroying the context also closes the socket created from it;
        # linger=0 discards any pending messages immediately.
        sync_ctx.destroy(linger=0)
@staticmethod
def make_engine_core_process(
    vllm_config: VllmConfig,
    executor_class: Type[Executor],
    usage_context: UsageContext,
    input_path: str,
    output_path: str,
    ready_path: str,
) -> EngineCoreProcHandle:
    """Start EngineCore in a background process and wait for readiness.

    All arguments are forwarded as keyword arguments to
    ``EngineCoreProc.run_engine_core`` in the new process. The call
    returns only after ``EngineCoreProc.wait_for_startup`` completes.

    Returns:
        An ``EngineCoreProcHandle`` holding the started process and the
        three paths it was given.
    """
    # Run EngineCore busy loop in background process.
    engine_proc = get_mp_context().Process(
        target=EngineCoreProc.run_engine_core,
        kwargs=dict(
            input_path=input_path,
            output_path=output_path,
            ready_path=ready_path,
            vllm_config=vllm_config,
            executor_class=executor_class,
            usage_context=usage_context,
        ),
    )
    engine_proc.start()

    # Do not hand the process back until it has reported readiness.
    EngineCoreProc.wait_for_startup(engine_proc, ready_path)

    return EngineCoreProcHandle(
        proc=engine_proc,
        ready_path=ready_path,
        input_path=input_path,
        output_path=output_path,
    )
ready_pipe.send({"status": "READY"})
@staticmethod
def run_engine_core(*args, **kwargs):
@@ -258,6 +186,7 @@ class EngineCoreProc(EngineCore):
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
parent_process = psutil.Process().parent()
engine_core = None
try:
engine_core = EngineCoreProc(*args, **kwargs)
@@ -266,9 +195,10 @@ class EngineCoreProc(EngineCore):
except SystemExit:
logger.debug("EngineCore interrupted.")
except BaseException as e:
logger.exception(e)
raise e
except Exception:
traceback = get_exception_traceback()
logger.error("EngineCore hit an exception: %s", traceback)
parent_process.send_signal(signal.SIGQUIT)
finally:
if engine_core is not None:
@@ -309,6 +239,9 @@ class EngineCoreProc(EngineCore):
def _log_stats(self):
"""Log basic stats every LOGGING_TIME_S"""
if not self.log_stats:
return
now = time.time()
if now - self._last_logging_time > LOGGING_TIME_S:
@@ -339,7 +272,7 @@ class EngineCoreProc(EngineCore):
decoder_add_req = PickleEncoder()
decoder_abort_req = PickleEncoder()
with make_zmq_socket(input_path, zmq.constants.PULL) as socket:
with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
while True:
# (RequestType, RequestData)
type_frame, data_frame = socket.recv_multipart(copy=False)
@@ -367,7 +300,7 @@ class EngineCoreProc(EngineCore):
# Reuse send buffer.
buffer = bytearray()
with make_zmq_socket(output_path, zmq.constants.PUSH) as socket:
with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket:
while True:
engine_core_outputs = self.output_queue.get()
outputs = EngineCoreOutputs(outputs=engine_core_outputs)