[Bug][Frontend] Improve ZMQ client robustness (#7443)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
@@ -6,7 +6,7 @@ import os
|
||||
import re
|
||||
import tempfile
|
||||
from argparse import Namespace
|
||||
from contextlib import asynccontextmanager
|
||||
from contextlib import asynccontextmanager, suppress
|
||||
from http import HTTPStatus
|
||||
from typing import AsyncIterator, Optional, Set
|
||||
|
||||
@@ -83,7 +83,8 @@ async def lifespan(app: FastAPI):
|
||||
async def _force_log():
|
||||
while True:
|
||||
await asyncio.sleep(10)
|
||||
await async_engine_client.do_log_stats()
|
||||
with suppress(Exception):
|
||||
await async_engine_client.do_log_stats()
|
||||
|
||||
if not engine_args.disable_log_stats:
|
||||
task = asyncio.create_task(_force_log())
|
||||
|
||||
@@ -10,10 +10,6 @@ from vllm.sampling_params import SamplingParams
|
||||
# Success string used for RPC instructions.
|
||||
VLLM_RPC_SUCCESS_STR = "SUCCESS"
|
||||
|
||||
# Timeouts.
|
||||
VLLM_RPC_SERVER_START_TIMEOUT_MS = 1000
|
||||
VLLM_RPC_HEALTH_TIMEOUT_MS = 10000
|
||||
|
||||
# Minimum value of ZMQ.SOCKET_LIMIT to run mp.
|
||||
VLLM_RPC_SOCKET_LIMIT_CUTOFF = 2000
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
from contextlib import contextmanager
|
||||
from contextlib import contextmanager, suppress
|
||||
from typing import Any, AsyncGenerator, Mapping, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -11,13 +11,12 @@ from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig)
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
|
||||
VLLM_RPC_HEALTH_TIMEOUT_MS,
|
||||
VLLM_RPC_SERVER_START_TIMEOUT_MS,
|
||||
VLLM_RPC_SOCKET_LIMIT_CUTOFF,
|
||||
VLLM_RPC_SUCCESS_STR,
|
||||
VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
|
||||
RPCGenerateRequest, RPCUtilityRequest)
|
||||
# yapf: enable
|
||||
from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS
|
||||
from vllm.inputs import PromptInputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
@@ -32,6 +31,17 @@ logger = init_logger(__name__)
|
||||
INPROC_PROXY_PATH = f"inproc://{uuid4()}"
|
||||
|
||||
|
||||
class RPCClientClosedError(Exception):
|
||||
"""Exception class raised when the client is used post-close.
|
||||
|
||||
The client can be closed, which closes the ZMQ context. This normally
|
||||
happens on server shutdown. In some cases, methods like abort and
|
||||
do_log_stats will still be called and then try to open a socket, which
|
||||
causes a ZMQError and creates a huge stack trace.
|
||||
So, we throw this error such that we can suppress it.
|
||||
"""
|
||||
|
||||
|
||||
class AsyncEngineRPCClient:
|
||||
"""
|
||||
RPCClient that connects to the RPCServer wrapping AsyncLLMEngine.
|
||||
@@ -85,6 +95,8 @@ class AsyncEngineRPCClient:
|
||||
|
||||
def __init__(self, rpc_path: str):
|
||||
self.context = zmq.asyncio.Context()
|
||||
self._data_timeout = VLLM_RPC_GET_DATA_TIMEOUT_MS
|
||||
self._errored = False
|
||||
|
||||
# Maximum number of sockets that can be opened (typically 65536).
|
||||
# ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
|
||||
@@ -143,7 +155,6 @@ class AsyncEngineRPCClient:
|
||||
|
||||
# Wait until server is ready.
|
||||
await self._wait_for_server_rpc()
|
||||
self._errored = False
|
||||
|
||||
# Get the configs.
|
||||
self.model_config = await self._get_model_config_rpc()
|
||||
@@ -170,6 +181,15 @@ class AsyncEngineRPCClient:
|
||||
@contextmanager
|
||||
def to_proxy_socket(self):
|
||||
# Connect to the RPCServer via the proxy.
|
||||
|
||||
# Raise a sensible error if the client was already closed.
|
||||
# This can happen if a server shutdown is triggered but some coroutines
|
||||
# are still running requests.
|
||||
# There should not be a race condition with this check because we don't
|
||||
# yield to the event loop between here and opening the socket.
|
||||
if self.context.closed:
|
||||
raise RPCClientClosedError("The ZMQ client has already shut down")
|
||||
|
||||
# Note that we use DEALER to enable asynchronous communication
|
||||
# to enable streaming.
|
||||
socket = self.context.socket(zmq.constants.DEALER)
|
||||
@@ -189,9 +209,18 @@ class AsyncEngineRPCClient:
|
||||
# Ping RPCServer with a request.
|
||||
await socket.send_multipart([cloudpickle.dumps(request)])
|
||||
|
||||
# Make sure the server responds
|
||||
if await socket.poll(timeout=self._data_timeout) == 0:
|
||||
raise TimeoutError("Server didn't reply within "
|
||||
f"{self._data_timeout} ms")
|
||||
|
||||
# Await the data from the Server.
|
||||
data = cloudpickle.loads(await socket.recv())
|
||||
|
||||
if isinstance(data, Exception):
|
||||
# Re-raise exceptions returned by the server
|
||||
raise data
|
||||
|
||||
if not isinstance(data, expected_type):
|
||||
# LoRAConfig can be None.
|
||||
if expected_type == LoRAConfig and data is None:
|
||||
@@ -208,29 +237,28 @@ class AsyncEngineRPCClient:
|
||||
self,
|
||||
request: RPC_REQUEST_TYPE,
|
||||
error_message: str,
|
||||
timeout: Optional[int] = None,
|
||||
socket: Optional[zmq.asyncio.Socket] = None):
|
||||
"""Send one-way RPC request to trigger an action."""
|
||||
|
||||
async def do_rpc_call(socket: zmq.asyncio.Socket,
|
||||
request: RPC_REQUEST_TYPE,
|
||||
timeout=None):
|
||||
request: RPC_REQUEST_TYPE):
|
||||
|
||||
await socket.send_multipart([cloudpickle.dumps(request)])
|
||||
|
||||
if timeout is not None and await socket.poll(timeout=timeout) == 0:
|
||||
raise TimeoutError(f"Server didn't reply within {timeout} ms")
|
||||
if await socket.poll(timeout=self._data_timeout) == 0:
|
||||
raise TimeoutError("Server didn't reply within "
|
||||
f"{self._data_timeout} ms")
|
||||
|
||||
return cloudpickle.loads(await socket.recv())
|
||||
|
||||
# Make a new socket connection.
|
||||
if socket is None:
|
||||
with self.to_proxy_socket() as socket:
|
||||
response = await do_rpc_call(socket, request, timeout)
|
||||
response = await do_rpc_call(socket, request)
|
||||
|
||||
# Use existing socket connection.
|
||||
else:
|
||||
response = await do_rpc_call(socket, request, timeout)
|
||||
response = await do_rpc_call(socket, request)
|
||||
|
||||
if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
|
||||
if isinstance(response, Exception):
|
||||
@@ -255,8 +283,7 @@ class AsyncEngineRPCClient:
|
||||
|
||||
await self._send_one_way_rpc_request(
|
||||
request=RPCUtilityRequest.IS_SERVER_READY,
|
||||
error_message="Unable to start RPC Server",
|
||||
timeout=VLLM_RPC_SERVER_START_TIMEOUT_MS)
|
||||
error_message="Unable to start RPC Server")
|
||||
|
||||
async def _get_model_config_rpc(self) -> ModelConfig:
|
||||
"""Get the ModelConfig object from the RPC Server"""
|
||||
@@ -308,17 +335,17 @@ class AsyncEngineRPCClient:
|
||||
|
||||
async def abort(self, request_id: str):
|
||||
"""Send an ABORT_REQUEST signal to the RPC Server"""
|
||||
|
||||
await self._send_one_way_rpc_request(
|
||||
request=RPCAbortRequest(request_id),
|
||||
error_message=f"RPCAbortRequest {request_id} failed")
|
||||
with suppress(RPCClientClosedError):
|
||||
await self._send_one_way_rpc_request(
|
||||
request=RPCAbortRequest(request_id),
|
||||
error_message=f"RPCAbortRequest {request_id} failed")
|
||||
|
||||
async def do_log_stats(self):
|
||||
"""Send a DO_LOG_STATS signal to the RPC Server"""
|
||||
|
||||
await self._send_one_way_rpc_request(
|
||||
request=RPCUtilityRequest.DO_LOG_STATS,
|
||||
error_message="RPCRequest DO_LOG_STATS failed.")
|
||||
with suppress(RPCClientClosedError):
|
||||
await self._send_one_way_rpc_request(
|
||||
request=RPCUtilityRequest.DO_LOG_STATS,
|
||||
error_message="RPCRequest DO_LOG_STATS failed.")
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
@@ -393,7 +420,6 @@ class AsyncEngineRPCClient:
|
||||
await self._send_one_way_rpc_request(
|
||||
request=RPCUtilityRequest.IS_SERVER_HEALTHY,
|
||||
error_message="Got Unhealthy response from RPC Server",
|
||||
timeout=VLLM_RPC_HEALTH_TIMEOUT_MS,
|
||||
socket=socket)
|
||||
|
||||
async def encode(self, *args,
|
||||
|
||||
Reference in New Issue
Block a user