[Bug][Frontend] Improve ZMQ client robustness (#7443)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
Joe Runde
2024-08-21 20:18:11 -06:00
committed by GitHub
parent df1a21131d
commit cde9183b40
6 changed files with 176 additions and 28 deletions

View File

@@ -6,7 +6,7 @@ import os
import re
import tempfile
from argparse import Namespace
from contextlib import asynccontextmanager
from contextlib import asynccontextmanager, suppress
from http import HTTPStatus
from typing import AsyncIterator, Optional, Set
@@ -83,7 +83,8 @@ async def lifespan(app: FastAPI):
async def _force_log():
while True:
await asyncio.sleep(10)
await async_engine_client.do_log_stats()
with suppress(Exception):
await async_engine_client.do_log_stats()
if not engine_args.disable_log_stats:
task = asyncio.create_task(_force_log())

View File

@@ -10,10 +10,6 @@ from vllm.sampling_params import SamplingParams
# Success string used for RPC instructions.
VLLM_RPC_SUCCESS_STR = "SUCCESS"
# Timeouts.
VLLM_RPC_SERVER_START_TIMEOUT_MS = 1000
VLLM_RPC_HEALTH_TIMEOUT_MS = 10000
# Minimum value of ZMQ.SOCKET_LIMIT to run mp.
VLLM_RPC_SOCKET_LIMIT_CUTOFF = 2000

View File

@@ -1,5 +1,5 @@
import asyncio
from contextlib import contextmanager
from contextlib import contextmanager, suppress
from typing import Any, AsyncGenerator, Mapping, Optional
from uuid import uuid4
@@ -11,13 +11,12 @@ from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
# yapf: disable
from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
VLLM_RPC_HEALTH_TIMEOUT_MS,
VLLM_RPC_SERVER_START_TIMEOUT_MS,
VLLM_RPC_SOCKET_LIMIT_CUTOFF,
VLLM_RPC_SUCCESS_STR,
VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
RPCGenerateRequest, RPCUtilityRequest)
# yapf: enable
from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS
from vllm.inputs import PromptInputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -32,6 +31,17 @@ logger = init_logger(__name__)
INPROC_PROXY_PATH = f"inproc://{uuid4()}"
class RPCClientClosedError(Exception):
"""Exception class raised when the client is used post-close.
The client can be closed, which closes the ZMQ context. This normally
happens on server shutdown. In some cases, methods like abort and
do_log_stats will still be called and then try to open a socket, which
causes a ZMQError and creates a huge stack trace.
So, we throw this error such that we can suppress it.
"""
class AsyncEngineRPCClient:
"""
RPCClient that connects to the RPCServer wrapping AsyncLLMEngine.
@@ -85,6 +95,8 @@ class AsyncEngineRPCClient:
def __init__(self, rpc_path: str):
self.context = zmq.asyncio.Context()
self._data_timeout = VLLM_RPC_GET_DATA_TIMEOUT_MS
self._errored = False
# Maximum number of sockets that can be opened (typically 65536).
# ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
@@ -143,7 +155,6 @@ class AsyncEngineRPCClient:
# Wait until server is ready.
await self._wait_for_server_rpc()
self._errored = False
# Get the configs.
self.model_config = await self._get_model_config_rpc()
@@ -170,6 +181,15 @@ class AsyncEngineRPCClient:
@contextmanager
def to_proxy_socket(self):
# Connect to the RPCServer via the proxy.
# Raise a sensible error if the client was already closed.
# This can happen if a server shutdown is triggered but some coroutines
# are still running requests.
# There should not be a race condition with this check because we don't
# yield to the event loop between here and opening the socket.
if self.context.closed:
raise RPCClientClosedError("The ZMQ client has already shut down")
# Note that we use DEALER to enable asynchronous communication
# to enable streaming.
socket = self.context.socket(zmq.constants.DEALER)
@@ -189,9 +209,18 @@ class AsyncEngineRPCClient:
# Ping RPCServer with a request.
await socket.send_multipart([cloudpickle.dumps(request)])
# Make sure the server responds
if await socket.poll(timeout=self._data_timeout) == 0:
raise TimeoutError("Server didn't reply within "
f"{self._data_timeout} ms")
# Await the data from the Server.
data = cloudpickle.loads(await socket.recv())
if isinstance(data, Exception):
# Re-raise exceptions returned by the server
raise data
if not isinstance(data, expected_type):
# LoRAConfig can be None.
if expected_type == LoRAConfig and data is None:
@@ -208,29 +237,28 @@ class AsyncEngineRPCClient:
self,
request: RPC_REQUEST_TYPE,
error_message: str,
timeout: Optional[int] = None,
socket: Optional[zmq.asyncio.Socket] = None):
"""Send one-way RPC request to trigger an action."""
async def do_rpc_call(socket: zmq.asyncio.Socket,
request: RPC_REQUEST_TYPE,
timeout=None):
request: RPC_REQUEST_TYPE):
await socket.send_multipart([cloudpickle.dumps(request)])
if timeout is not None and await socket.poll(timeout=timeout) == 0:
raise TimeoutError(f"Server didn't reply within {timeout} ms")
if await socket.poll(timeout=self._data_timeout) == 0:
raise TimeoutError("Server didn't reply within "
f"{self._data_timeout} ms")
return cloudpickle.loads(await socket.recv())
# Make a new socket connection.
if socket is None:
with self.to_proxy_socket() as socket:
response = await do_rpc_call(socket, request, timeout)
response = await do_rpc_call(socket, request)
# Use existing socket connection.
else:
response = await do_rpc_call(socket, request, timeout)
response = await do_rpc_call(socket, request)
if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
if isinstance(response, Exception):
@@ -255,8 +283,7 @@ class AsyncEngineRPCClient:
await self._send_one_way_rpc_request(
request=RPCUtilityRequest.IS_SERVER_READY,
error_message="Unable to start RPC Server",
timeout=VLLM_RPC_SERVER_START_TIMEOUT_MS)
error_message="Unable to start RPC Server")
async def _get_model_config_rpc(self) -> ModelConfig:
"""Get the ModelConfig object from the RPC Server"""
@@ -308,17 +335,17 @@ class AsyncEngineRPCClient:
async def abort(self, request_id: str):
"""Send an ABORT_REQUEST signal to the RPC Server"""
await self._send_one_way_rpc_request(
request=RPCAbortRequest(request_id),
error_message=f"RPCAbortRequest {request_id} failed")
with suppress(RPCClientClosedError):
await self._send_one_way_rpc_request(
request=RPCAbortRequest(request_id),
error_message=f"RPCAbortRequest {request_id} failed")
async def do_log_stats(self):
"""Send a DO_LOG_STATS signal to the RPC Server"""
await self._send_one_way_rpc_request(
request=RPCUtilityRequest.DO_LOG_STATS,
error_message="RPCRequest DO_LOG_STATS failed.")
with suppress(RPCClientClosedError):
await self._send_one_way_rpc_request(
request=RPCUtilityRequest.DO_LOG_STATS,
error_message="RPCRequest DO_LOG_STATS failed.")
@property
def is_running(self) -> bool:
@@ -393,7 +420,6 @@ class AsyncEngineRPCClient:
await self._send_one_way_rpc_request(
request=RPCUtilityRequest.IS_SERVER_HEALTHY,
error_message="Got Unhealthy response from RPC Server",
timeout=VLLM_RPC_HEALTH_TIMEOUT_MS,
socket=socket)
async def encode(self, *args,