[Perf] API-server scaleout with many-to-many server-engine comms (#17546)

This commit is contained in:
Nick Hill
2025-05-30 08:17:00 -07:00
committed by GitHub
parent 84ec470fca
commit 2dbe8c0774
26 changed files with 1828 additions and 436 deletions

View File

@@ -17,7 +17,7 @@ from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
from json import JSONDecodeError
from typing import Annotated, Optional
from typing import Annotated, Any, Optional
import prometheus_client
import regex as re
@@ -26,6 +26,8 @@ from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
from prometheus_fastapi_instrumentator import Instrumentator
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import State
from starlette.routing import Mount
@@ -97,6 +99,7 @@ from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path,
is_valid_ipv6_address, set_ulimit)
from vllm.v1.metrics.prometheus import get_prometheus_registry
from vllm.version import __version__ as VLLM_VERSION
TIMEOUT_KEEP_ALIVE = 5 # seconds
@@ -142,14 +145,17 @@ async def lifespan(app: FastAPI):
@asynccontextmanager
async def build_async_engine_client(
args: Namespace) -> AsyncIterator[EngineClient]:
args: Namespace,
client_config: Optional[dict[str, Any]] = None,
) -> AsyncIterator[EngineClient]:
# Context manager to handle engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
engine_args = AsyncEngineArgs.from_cli_args(args)
async with build_async_engine_client_from_engine_args(
engine_args, args.disable_frontend_multiprocessing) as engine:
engine_args, args.disable_frontend_multiprocessing,
client_config) as engine:
yield engine
@@ -157,6 +163,7 @@ async def build_async_engine_client(
async def build_async_engine_client_from_engine_args(
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
client_config: Optional[dict[str, Any]] = None,
) -> AsyncIterator[EngineClient]:
"""
Create EngineClient, either:
@@ -179,12 +186,16 @@ async def build_async_engine_client_from_engine_args(
from vllm.v1.engine.async_llm import AsyncLLM
async_llm: Optional[AsyncLLM] = None
client_index = client_config.pop(
"client_index") if client_config else 0
try:
async_llm = AsyncLLM.from_vllm_config(
vllm_config=vllm_config,
usage_context=usage_context,
disable_log_requests=engine_args.disable_log_requests,
disable_log_stats=engine_args.disable_log_stats)
disable_log_stats=engine_args.disable_log_stats,
client_addresses=client_config,
client_index=client_index)
# Don't keep the dummy data in memory
await async_llm.reset_mm_cache()
@@ -318,22 +329,9 @@ class PrometheusResponse(Response):
def mount_metrics(app: FastAPI):
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import (REGISTRY, CollectorRegistry, make_asgi_app,
multiprocess)
from prometheus_fastapi_instrumentator import Instrumentator
"""Mount prometheus metrics to a FastAPI app."""
registry = REGISTRY
prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
if prometheus_multiproc_dir_path is not None:
logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
prometheus_multiproc_dir_path)
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
registry = get_prometheus_registry()
# `response_class=PrometheusResponse` is needed to return an HTTP response
# with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
@@ -1256,16 +1254,10 @@ def create_server_socket(addr: tuple[str, int]) -> socket.socket:
return sock
async def run_server(args, **uvicorn_kwargs) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
log_non_default_args(args)
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
def validate_api_server_args(args):
valid_tool_parses = ToolParserManager.tool_parsers.keys()
if args.enable_auto_tool_choice \
and args.tool_call_parser not in valid_tool_parses:
and args.tool_call_parser not in valid_tool_parses:
raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
f"(chose from {{ {','.join(valid_tool_parses)} }})")
@@ -1276,6 +1268,19 @@ async def run_server(args, **uvicorn_kwargs) -> None:
f"invalid reasoning parser: {args.reasoning_parser} "
f"(chose from {{ {','.join(valid_reasoning_parses)} }})")
def setup_server(args):
"""Validate API server args, set up signal handler, create socket
ready to serve."""
logger.info("vLLM API server version %s", VLLM_VERSION)
log_non_default_args(args)
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
validate_api_server_args(args)
# workaround to make sure that we bind the port before the engine is set up.
# This avoids race conditions with ray.
# see https://github.com/vllm-project/vllm/issues/8204
@@ -1292,22 +1297,41 @@ async def run_server(args, **uvicorn_kwargs) -> None:
signal.signal(signal.SIGTERM, signal_handler)
async with build_async_engine_client(args) as engine_client:
addr, port = sock_addr
is_ssl = args.ssl_keyfile and args.ssl_certfile
host_part = f"[{addr}]" if is_valid_ipv6_address(
addr) else addr or "0.0.0.0"
listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}"
return listen_address, sock
async def run_server(args, **uvicorn_kwargs) -> None:
"""Run a single-worker API server."""
listen_address, sock = setup_server(args)
await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)
async def run_server_worker(listen_address,
sock,
args,
client_config=None,
**uvicorn_kwargs) -> None:
"""Run a single API server worker."""
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
server_index = client_config.get("client_index", 0) if client_config else 0
async with build_async_engine_client(args, client_config) as engine_client:
app = build_app(args)
vllm_config = await engine_client.get_vllm_config()
await init_app_state(engine_client, vllm_config, app.state, args)
def _listen_addr(a: str) -> str:
if is_valid_ipv6_address(a):
return '[' + a + ']'
return a or "0.0.0.0"
is_ssl = args.ssl_keyfile and args.ssl_certfile
logger.info("Starting vLLM API server on http%s://%s:%d",
"s" if is_ssl else "", _listen_addr(sock_addr[0]),
sock_addr[1])
logger.info("Starting vLLM API server %d on %s", server_index,
listen_address)
shutdown_task = await serve_http(
app,
sock=sock,