[Perf] API-server scaleout with many-to-many server-engine comms (#17546)
This commit is contained in:
@@ -17,7 +17,7 @@ from contextlib import asynccontextmanager
|
||||
from functools import partial
|
||||
from http import HTTPStatus
|
||||
from json import JSONDecodeError
|
||||
from typing import Annotated, Optional
|
||||
from typing import Annotated, Any, Optional
|
||||
|
||||
import prometheus_client
|
||||
import regex as re
|
||||
@@ -26,6 +26,8 @@ from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
from prometheus_client import make_asgi_app
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from starlette.concurrency import iterate_in_threadpool
|
||||
from starlette.datastructures import State
|
||||
from starlette.routing import Mount
|
||||
@@ -97,6 +99,7 @@ from vllm.transformers_utils.tokenizer import MistralTokenizer
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path,
|
||||
is_valid_ipv6_address, set_ulimit)
|
||||
from vllm.v1.metrics.prometheus import get_prometheus_registry
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
TIMEOUT_KEEP_ALIVE = 5 # seconds
|
||||
@@ -142,14 +145,17 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
@asynccontextmanager
|
||||
async def build_async_engine_client(
|
||||
args: Namespace) -> AsyncIterator[EngineClient]:
|
||||
args: Namespace,
|
||||
client_config: Optional[dict[str, Any]] = None,
|
||||
) -> AsyncIterator[EngineClient]:
|
||||
|
||||
# Context manager to handle engine_client lifecycle
|
||||
# Ensures everything is shutdown and cleaned up on error/exit
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
|
||||
async with build_async_engine_client_from_engine_args(
|
||||
engine_args, args.disable_frontend_multiprocessing) as engine:
|
||||
engine_args, args.disable_frontend_multiprocessing,
|
||||
client_config) as engine:
|
||||
yield engine
|
||||
|
||||
|
||||
@@ -157,6 +163,7 @@ async def build_async_engine_client(
|
||||
async def build_async_engine_client_from_engine_args(
|
||||
engine_args: AsyncEngineArgs,
|
||||
disable_frontend_multiprocessing: bool = False,
|
||||
client_config: Optional[dict[str, Any]] = None,
|
||||
) -> AsyncIterator[EngineClient]:
|
||||
"""
|
||||
Create EngineClient, either:
|
||||
@@ -179,12 +186,16 @@ async def build_async_engine_client_from_engine_args(
|
||||
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
async_llm: Optional[AsyncLLM] = None
|
||||
client_index = client_config.pop(
|
||||
"client_index") if client_config else 0
|
||||
try:
|
||||
async_llm = AsyncLLM.from_vllm_config(
|
||||
vllm_config=vllm_config,
|
||||
usage_context=usage_context,
|
||||
disable_log_requests=engine_args.disable_log_requests,
|
||||
disable_log_stats=engine_args.disable_log_stats)
|
||||
disable_log_stats=engine_args.disable_log_stats,
|
||||
client_addresses=client_config,
|
||||
client_index=client_index)
|
||||
|
||||
# Don't keep the dummy data in memory
|
||||
await async_llm.reset_mm_cache()
|
||||
@@ -318,22 +329,9 @@ class PrometheusResponse(Response):
|
||||
|
||||
|
||||
def mount_metrics(app: FastAPI):
|
||||
# Lazy import for prometheus multiprocessing.
|
||||
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
|
||||
# before prometheus_client is imported.
|
||||
# See https://prometheus.github.io/client_python/multiprocess/
|
||||
from prometheus_client import (REGISTRY, CollectorRegistry, make_asgi_app,
|
||||
multiprocess)
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
"""Mount prometheus metrics to a FastAPI app."""
|
||||
|
||||
registry = REGISTRY
|
||||
|
||||
prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
|
||||
if prometheus_multiproc_dir_path is not None:
|
||||
logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
|
||||
prometheus_multiproc_dir_path)
|
||||
registry = CollectorRegistry()
|
||||
multiprocess.MultiProcessCollector(registry)
|
||||
registry = get_prometheus_registry()
|
||||
|
||||
# `response_class=PrometheusResponse` is needed to return an HTTP response
|
||||
# with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
|
||||
@@ -1256,16 +1254,10 @@ def create_server_socket(addr: tuple[str, int]) -> socket.socket:
|
||||
return sock
|
||||
|
||||
|
||||
async def run_server(args, **uvicorn_kwargs) -> None:
|
||||
logger.info("vLLM API server version %s", VLLM_VERSION)
|
||||
log_non_default_args(args)
|
||||
|
||||
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
|
||||
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
|
||||
|
||||
def validate_api_server_args(args):
|
||||
valid_tool_parses = ToolParserManager.tool_parsers.keys()
|
||||
if args.enable_auto_tool_choice \
|
||||
and args.tool_call_parser not in valid_tool_parses:
|
||||
and args.tool_call_parser not in valid_tool_parses:
|
||||
raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
|
||||
f"(chose from {{ {','.join(valid_tool_parses)} }})")
|
||||
|
||||
@@ -1276,6 +1268,19 @@ async def run_server(args, **uvicorn_kwargs) -> None:
|
||||
f"invalid reasoning parser: {args.reasoning_parser} "
|
||||
f"(chose from {{ {','.join(valid_reasoning_parses)} }})")
|
||||
|
||||
|
||||
def setup_server(args):
|
||||
"""Validate API server args, set up signal handler, create socket
|
||||
ready to serve."""
|
||||
|
||||
logger.info("vLLM API server version %s", VLLM_VERSION)
|
||||
log_non_default_args(args)
|
||||
|
||||
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
|
||||
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
|
||||
|
||||
validate_api_server_args(args)
|
||||
|
||||
# workaround to make sure that we bind the port before the engine is set up.
|
||||
# This avoids race conditions with ray.
|
||||
# see https://github.com/vllm-project/vllm/issues/8204
|
||||
@@ -1292,22 +1297,41 @@ async def run_server(args, **uvicorn_kwargs) -> None:
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
async with build_async_engine_client(args) as engine_client:
|
||||
addr, port = sock_addr
|
||||
is_ssl = args.ssl_keyfile and args.ssl_certfile
|
||||
host_part = f"[{addr}]" if is_valid_ipv6_address(
|
||||
addr) else addr or "0.0.0.0"
|
||||
listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}"
|
||||
|
||||
return listen_address, sock
|
||||
|
||||
|
||||
async def run_server(args, **uvicorn_kwargs) -> None:
|
||||
"""Run a single-worker API server."""
|
||||
listen_address, sock = setup_server(args)
|
||||
await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)
|
||||
|
||||
|
||||
async def run_server_worker(listen_address,
|
||||
sock,
|
||||
args,
|
||||
client_config=None,
|
||||
**uvicorn_kwargs) -> None:
|
||||
"""Run a single API server worker."""
|
||||
|
||||
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
|
||||
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
|
||||
|
||||
server_index = client_config.get("client_index", 0) if client_config else 0
|
||||
|
||||
async with build_async_engine_client(args, client_config) as engine_client:
|
||||
app = build_app(args)
|
||||
|
||||
vllm_config = await engine_client.get_vllm_config()
|
||||
await init_app_state(engine_client, vllm_config, app.state, args)
|
||||
|
||||
def _listen_addr(a: str) -> str:
|
||||
if is_valid_ipv6_address(a):
|
||||
return '[' + a + ']'
|
||||
return a or "0.0.0.0"
|
||||
|
||||
is_ssl = args.ssl_keyfile and args.ssl_certfile
|
||||
logger.info("Starting vLLM API server on http%s://%s:%d",
|
||||
"s" if is_ssl else "", _listen_addr(sock_addr[0]),
|
||||
sock_addr[1])
|
||||
|
||||
logger.info("Starting vLLM API server %d on %s", server_index,
|
||||
listen_address)
|
||||
shutdown_task = await serve_http(
|
||||
app,
|
||||
sock=sock,
|
||||
|
||||
Reference in New Issue
Block a user