Files
vllm/vllm/entrypoints/openai/api_server.py
Patrick von Platen 10152d2194 [Realtime API] Adds minimal realtime API based on websockets (#33187)
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
2026-01-30 18:41:29 +08:00

506 lines
18 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
import inspect
import multiprocessing
import multiprocessing.forkserver as forkserver
import os
import signal
import socket
import tempfile
from argparse import Namespace
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any
import uvloop
from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from starlette.datastructures import State
import vllm.envs as envs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.server_utils import (
get_uvicorn_log_config,
http_exception_handler,
lifespan,
log_response,
validation_exception_handler,
)
from vllm.entrypoints.sagemaker.api_router import sagemaker_standards_bootstrap
from vllm.entrypoints.serve.elastic_ep.middleware import (
ScalingMiddleware,
)
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
from vllm.entrypoints.utils import (
cli_env_setup,
log_non_default_args,
log_version_and_model,
process_lora_modules,
)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.tool_parsers import ToolParserManager
from vllm.usage.usage_lib import UsageContext
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.network_utils import is_valid_ipv6_address
from vllm.utils.system_utils import decorate_logs, set_ulimit
from vllm.version import __version__ as VLLM_VERSION
# Module-level annotation only: no value is assigned to this name in the code
# visible here — presumably it is assigned elsewhere when Prometheus
# multiprocess metrics are set up (TODO confirm).
prometheus_multiproc_dir: tempfile.TemporaryDirectory
# Module logger. The name is hard-coded because __name__ cannot be used here
# (https://github.com/vllm-project/vllm/pull/4765).
logger = init_logger("vllm.entrypoints.openai.api_server")
@asynccontextmanager
async def build_async_engine_client(
    args: Namespace,
    *,
    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
    disable_frontend_multiprocessing: bool | None = None,
    client_config: dict[str, Any] | None = None,
) -> AsyncIterator[EngineClient]:
    """Build an engine client from parsed CLI args and manage its lifecycle.

    Thin wrapper around ``build_async_engine_client_from_engine_args``:
    converts ``args`` into ``AsyncEngineArgs``, forwards the API-process
    count/rank from ``client_config`` (when one is given), and falls back to
    ``args.disable_frontend_multiprocessing`` when the caller passes ``None``
    for ``disable_frontend_multiprocessing``.

    When the ``VLLM_WORKER_MULTIPROC_METHOD`` environment variable is
    ``"forkserver"``, the forkserver is started first with
    ``vllm.v1.engine.async_llm`` preloaded.

    Yields:
        The engine client; it is shut down when the context exits.
    """
    wants_forkserver = os.getenv("VLLM_WORKER_MULTIPROC_METHOD") == "forkserver"
    if wants_forkserver:
        # The executor is expected to be mp; pre-import heavy modules inside
        # the forkserver process.
        logger.debug("Setup forkserver with pre-imports")
        multiprocessing.set_start_method("forkserver")
        multiprocessing.set_forkserver_preload(["vllm.v1.engine.async_llm"])
        forkserver.ensure_running()
        logger.debug("Forkserver setup complete!")

    engine_args = AsyncEngineArgs.from_cli_args(args)
    if client_config:
        # Forward the API-process count and rank carried in client_config.
        engine_args._api_process_count = client_config.get("client_count", 1)
        engine_args._api_process_rank = client_config.get("client_index", 0)

    frontend_mp_disabled = (
        bool(args.disable_frontend_multiprocessing)
        if disable_frontend_multiprocessing is None
        else disable_frontend_multiprocessing
    )

    # The inner context manager owns the engine lifecycle and guarantees
    # shutdown/cleanup on error or exit.
    async with build_async_engine_client_from_engine_args(
        engine_args,
        usage_context=usage_context,
        disable_frontend_multiprocessing=frontend_mp_disabled,
        client_config=client_config,
    ) as client:
        yield client
@asynccontextmanager
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    *,
    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
    disable_frontend_multiprocessing: bool = False,
    client_config: dict[str, Any] | None = None,
) -> AsyncIterator[EngineClient]:
    """Create an ``AsyncLLM`` engine client and manage its lifecycle.

    The engine config is built from ``engine_args`` and an ``AsyncLLM`` is
    created from it. The engine's multimodal cache is reset and the client
    is yielded. When the context exits, whether normally or via an
    exception, the engine is shut down.

    Args:
        engine_args: Engine arguments used to create the engine config.
        usage_context: Forwarded to config creation and to ``AsyncLLM``.
        disable_frontend_multiprocessing: Only triggers a warning. The
            engine is always created via ``AsyncLLM.from_vllm_config``.
        client_config: Optional mapping. ``client_count`` (default 1) and
            ``client_index`` (default 0) are extracted and passed to
            ``AsyncLLM``. All remaining entries are passed as
            ``client_addresses``. The caller's dict is not mutated.

    Yields:
        The ready ``AsyncLLM`` instance.

    Raises:
        Any exception raised while creating the config or the engine
        propagates to the caller; a ``None`` client is never yielded.
    """
    # Create the engine config.
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    if disable_frontend_multiprocessing:
        logger.warning("V1 is enabled, but got --disable-frontend-multiprocessing.")

    from vllm.v1.engine.async_llm import AsyncLLM

    # Copy so the pops below don't mutate the caller's client_config.
    client_config = dict(client_config) if client_config else {}
    client_count = client_config.pop("client_count", 1)
    client_index = client_config.pop("client_index", 0)

    async_llm: AsyncLLM | None = None
    try:
        async_llm = AsyncLLM.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=usage_context,
            enable_log_requests=engine_args.enable_log_requests,
            aggregate_engine_logging=engine_args.aggregate_engine_logging,
            disable_log_stats=engine_args.disable_log_stats,
            client_addresses=client_config,
            client_count=client_count,
            client_index=client_index,
        )

        # Don't keep the dummy data in memory.
        await async_llm.reset_mm_cache()

        yield async_llm
    finally:
        # Only shut down an engine that was actually created.
        if async_llm is not None:
            async_llm.shutdown()
def _create_fastapi_app(args: Namespace) -> FastAPI:
    """Create the bare FastAPI app, configuring the docs endpoints from args."""
    if args.disable_fastapi_docs:
        return FastAPI(
            openapi_url=None, docs_url=None, redoc_url=None, lifespan=lifespan
        )
    if args.enable_offline_docs:
        return FastAPI(docs_url=None, redoc_url=None, lifespan=lifespan)
    return FastAPI(lifespan=lifespan)


def _register_api_routers(
    app: FastAPI, supported_tasks: tuple["SupportedTask", ...]
) -> None:
    """Attach the always-available routers, then the task-specific ones."""
    from vllm.entrypoints.openai.basic.api_router import register_basic_api_routers

    register_basic_api_routers(app)

    from vllm.entrypoints.serve import register_vllm_serve_api_routers

    register_vllm_serve_api_routers(app)

    from vllm.entrypoints.openai.models.api_router import (
        attach_router as register_models_api_router,
    )

    register_models_api_router(app)

    from vllm.entrypoints.sagemaker.api_router import (
        attach_router as register_sagemaker_api_router,
    )

    register_sagemaker_api_router(app, supported_tasks)

    # Task-specific routers are imported and registered only when the engine
    # reports the corresponding task as supported.
    if "generate" in supported_tasks:
        from vllm.entrypoints.openai.generate.api_router import (
            register_generate_api_routers,
        )

        register_generate_api_routers(app)

    if "transcription" in supported_tasks:
        from vllm.entrypoints.openai.translations.api_router import (
            attach_router as register_translations_api_router,
        )

        register_translations_api_router(app)

    if "realtime" in supported_tasks:
        from vllm.entrypoints.openai.realtime.api_router import (
            attach_router as register_realtime_api_router,
        )

        register_realtime_api_router(app)

    if any(task in POOLING_TASKS for task in supported_tasks):
        from vllm.entrypoints.pooling import register_pooling_api_routers

        register_pooling_api_routers(app, supported_tasks)


def _add_custom_middleware(app: FastAPI, middleware_paths) -> None:
    """Import and install user-supplied middleware given as dotted paths.

    Classes are added with ``app.add_middleware``; coroutine functions are
    registered as HTTP middleware. Anything else is rejected.

    Raises:
        ValueError: If a path is not of the form ``module.attribute``, or the
            imported object is neither a class nor a coroutine function.
    """
    for middleware in middleware_paths or ():
        module_path, _, object_name = middleware.rpartition(".")
        # Reject "name", ".name" and "module." up front with a clear message
        # instead of an opaque unpacking/import/attribute error.
        if not module_path or not object_name:
            raise ValueError(
                f"Invalid middleware {middleware}. Expected a dotted import "
                "path of the form 'module.attribute'."
            )
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)  # type: ignore[arg-type]
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(
                f"Invalid middleware {middleware}. Must be a function or a class."
            )


def build_app(args: Namespace, supported_tasks: tuple["SupportedTask", ...]) -> FastAPI:
    """Build the FastAPI application for the OpenAI-compatible server.

    Routers are registered according to ``supported_tasks``. The following
    are then installed in order:

    - CORS middleware
    - exception handlers
    - optional API-key authentication
    - optional request-id header middleware
    - the scaling middleware
    - optional response logging
    - any user-supplied middleware listed in ``args.middleware``

    Finally the app is passed through ``sagemaker_standards_bootstrap``.

    Raises:
        ValueError: If an ``args.middleware`` entry is malformed or does not
            resolve to a class or coroutine function.
    """
    app = _create_fastapi_app(args)
    app.state.args = args

    _register_api_routers(app, supported_tasks)

    app.root_path = args.root_path

    # NOTE: the order of the middleware registrations below is significant
    # (it determines how the middleware wrap each other); keep it unchanged.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    app.exception_handler(HTTPException)(http_exception_handler)
    app.exception_handler(RequestValidationError)(validation_exception_handler)

    # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
    if tokens := [key for key in (args.api_key or [envs.VLLM_API_KEY]) if key]:
        from vllm.entrypoints.openai.server_utils import AuthenticationMiddleware

        app.add_middleware(AuthenticationMiddleware, tokens=tokens)

    if args.enable_request_id_headers:
        from vllm.entrypoints.openai.server_utils import XRequestIdMiddleware

        app.add_middleware(XRequestIdMiddleware)

    # Add scaling middleware to check for scaling state
    app.add_middleware(ScalingMiddleware)

    if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
        logger.warning(
            "CAUTION: Enabling log response in the API Server. "
            "This can include sensitive information and should be "
            "avoided in production."
        )
        app.middleware("http")(log_response)

    _add_custom_middleware(app, args.middleware)

    return sagemaker_standards_bootstrap(app)
async def init_app_state(
    engine_client: EngineClient,
    state: State,
    args: Namespace,
    supported_tasks: tuple["SupportedTask", ...],
) -> None:
    """Populate ``state`` with the serving objects used by the API routes.

    Always stores the engine client, config and args, builds the model
    registry (loading static LoRA adapters) and the tokenization handler.

    Handlers for generation, transcription, realtime and pooling are
    initialized only when the matching task appears in ``supported_tasks``.
    Finally the server-load tracking fields are set.
    """
    vllm_config = engine_client.vllm_config

    model_names = (
        [args.model] if args.served_model_name is None else args.served_model_name
    )
    request_logger = (
        RequestLogger(max_log_len=args.max_log_len)
        if args.enable_log_requests
        else None
    )
    base_model_paths = [
        BaseModelPath(name=model_name, model_path=args.model)
        for model_name in model_names
    ]

    state.engine_client = engine_client
    state.log_stats = not args.disable_log_stats
    state.vllm_config = vllm_config
    state.args = args

    chat_template = load_chat_template(args.chat_template)

    # Merge the config's default multimodal LoRAs into the static lora_modules.
    lora_config = vllm_config.lora_config
    default_mm_loras = {} if lora_config is None else lora_config.default_mm_loras
    lora_modules = process_lora_modules(args.lora_modules, default_mm_loras)

    state.openai_serving_models = OpenAIServingModels(
        engine_client=engine_client,
        base_model_paths=base_model_paths,
        lora_modules=lora_modules,
    )
    await state.openai_serving_models.init_static_loras()

    state.openai_serving_tokenization = OpenAIServingTokenization(
        engine_client,
        state.openai_serving_models,
        request_logger=request_logger,
        chat_template=chat_template,
        chat_template_content_format=args.chat_template_content_format,
        trust_request_chat_template=args.trust_request_chat_template,
        log_error_stack=args.log_error_stack,
    )

    # Positional arguments shared by every task-specific state initializer.
    init_args = (engine_client, state, args, request_logger, supported_tasks)

    if "generate" in supported_tasks:
        from vllm.entrypoints.openai.generate.api_router import init_generate_state

        await init_generate_state(*init_args)

    if "transcription" in supported_tasks:
        from vllm.entrypoints.openai.translations.api_router import (
            init_transcription_state,
        )

        init_transcription_state(*init_args)

    if "realtime" in supported_tasks:
        from vllm.entrypoints.openai.realtime.api_router import init_realtime_state

        init_realtime_state(*init_args)

    if any(task in POOLING_TASKS for task in supported_tasks):
        from vllm.entrypoints.pooling import init_pooling_state

        init_pooling_state(*init_args)

    # Server-load tracking flag and counter (counter starts at 0).
    state.enable_server_load_tracking = args.enable_server_load_tracking
    state.server_load_metrics = 0
def create_server_socket(addr: tuple[str, int]) -> socket.socket:
    """Create a TCP stream socket bound to ``addr`` (``listen`` is not called).

    ``AF_INET6`` is used when the host part is a valid IPv6 address, otherwise
    ``AF_INET``. ``SO_REUSEADDR`` is always set; ``SO_REUSEPORT`` is set when
    the platform's ``socket`` module provides it.

    Args:
        addr: ``(host, port)`` pair passed straight to ``bind``.

    Returns:
        The bound socket. The caller owns it and is responsible for closing it.

    Raises:
        OSError: If setting options or binding fails. The socket is closed
            before the error propagates, so its file descriptor is not leaked.
    """
    family = socket.AF_INET6 if is_valid_ipv6_address(addr[0]) else socket.AF_INET
    sock = socket.socket(family=family, type=socket.SOCK_STREAM)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # SO_REUSEPORT is not defined on every platform (e.g. Windows).
        if hasattr(socket, "SO_REUSEPORT"):
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(addr)
    except BaseException:
        # e.g. "address already in use": don't leak the fd, re-raise as-is.
        sock.close()
        raise
    return sock
def create_server_unix_socket(path: str) -> socket.socket:
    """Create a Unix-domain stream socket bound to ``path`` (no ``listen`` yet).

    Args:
        path: Filesystem path to bind. Any pre-existing file at ``path`` is
            not removed by this function.

    Returns:
        The bound socket. The caller owns it and is responsible for closing it.

    Raises:
        OSError: If binding fails. The socket is closed before the error
            propagates, so its file descriptor is not leaked.
    """
    sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
    try:
        sock.bind(path)
    except BaseException:
        sock.close()
        raise
    return sock
def validate_api_server_args(args):
    """Check that the configured tool-call and reasoning parsers are registered.

    Raises:
        KeyError: If auto tool choice is enabled and ``args.tool_call_parser``
            is not a registered tool parser, or if
            ``args.structured_outputs_config.reasoning_parser`` is set to a
            parser that is not registered.
    """
    valid_tool_parsers = ToolParserManager.list_registered()
    if (
        args.enable_auto_tool_choice
        and args.tool_call_parser not in valid_tool_parsers
    ):
        raise KeyError(
            f"invalid tool call parser: {args.tool_call_parser} "
            f"(choose from {{ {','.join(valid_tool_parsers)} }})"
        )

    valid_reasoning_parsers = ReasoningParserManager.list_registered()
    reasoning_parser = args.structured_outputs_config.reasoning_parser
    if reasoning_parser and reasoning_parser not in valid_reasoning_parsers:
        raise KeyError(
            f"invalid reasoning parser: {reasoning_parser} "
            f"(choose from {{ {','.join(valid_reasoning_parsers)} }})"
        )
def setup_server(args):
    """Do the pre-engine setup for the API server and bind its socket.

    Steps, in order:

    1. Log the version/model and any non-default args.
    2. Import the tool/reasoning parser plugins, if configured.
    3. Validate the args.
    4. Bind the listening socket: a Unix domain socket when ``args.uds`` is
       set, TCP otherwise.
    5. Call ``set_ulimit()``.
    6. Install a SIGTERM handler that interrupts initialization.

    Returns:
        ``(listen_address, sock)``: a printable address (``unix:...`` or
        ``http(s)://host:port``) and the bound socket.
    """
    log_version_and_model(logger, VLLM_VERSION, args.model)
    log_non_default_args(args)

    tool_plugin = args.tool_parser_plugin
    if tool_plugin and len(tool_plugin) > 3:
        ToolParserManager.import_tool_parser(tool_plugin)
    reasoning_plugin = args.reasoning_parser_plugin
    if reasoning_plugin and len(reasoning_plugin) > 3:
        ReasoningParserManager.import_reasoning_parser(reasoning_plugin)

    validate_api_server_args(args)

    # Bind the port before the engine is set up, to avoid race conditions
    # with ray. See https://github.com/vllm-project/vllm/issues/8204
    uds_path = args.uds
    if uds_path:
        sock = create_server_unix_socket(uds_path)
    else:
        host, port = args.host or "", args.port
        sock = create_server_socket((host, port))

    # Workaround for uvicorn dropping requests when too many concurrent
    # requests are active.
    set_ulimit()

    def signal_handler(*_) -> None:
        # Interrupt server on sigterm while initializing
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, signal_handler)

    if uds_path:
        return f"unix:{uds_path}", sock

    scheme = "https" if args.ssl_keyfile and args.ssl_certfile else "http"
    if is_valid_ipv6_address(host):
        host_part = f"[{host}]"
    else:
        host_part = host or "0.0.0.0"
    return f"{scheme}://{host_part}:{port}", sock
async def run_server(args, **uvicorn_kwargs) -> None:
    """Run a single-worker API server.

    Prefixes this process's stdout/stderr with ``APIServer``, prepares the
    listening socket via ``setup_server``, then hands off to
    ``run_server_worker``.
    """
    # Add process-specific prefix to stdout and stderr.
    decorate_logs("APIServer")
    address_and_socket = setup_server(args)
    await run_server_worker(*address_and_socket, args, **uvicorn_kwargs)
async def run_server_worker(
    listen_address, sock, args, client_config=None, **uvicorn_kwargs
) -> None:
    """Run a single API server worker on an already-bound socket.

    Args:
        listen_address: Printable address; only used in the startup log line.
        sock: Bound listening socket to serve on. It is always closed before
            this coroutine returns or raises, including when startup fails.
        args: Parsed CLI namespace.
        client_config: Optional client config forwarded to
            ``build_async_engine_client``.
        **uvicorn_kwargs: Extra keyword arguments forwarded to ``serve_http``.
    """
    try:
        # Parser plugins are imported here as well as in setup_server —
        # presumably because this worker may run in a process where
        # setup_server was not called (TODO confirm).
        if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
            ToolParserManager.import_tool_parser(args.tool_parser_plugin)
        if args.reasoning_parser_plugin and len(args.reasoning_parser_plugin) > 3:
            ReasoningParserManager.import_reasoning_parser(
                args.reasoning_parser_plugin
            )

        # Get uvicorn log config (from file or with endpoint filter)
        log_config = get_uvicorn_log_config(args)
        if log_config is not None:
            uvicorn_kwargs["log_config"] = log_config

        async with build_async_engine_client(
            args,
            client_config=client_config,
        ) as engine_client:
            supported_tasks = await engine_client.get_supported_tasks()
            logger.info("Supported tasks: %s", supported_tasks)

            app = build_app(args, supported_tasks)
            await init_app_state(engine_client, app.state, args, supported_tasks)

            logger.info(
                "Starting vLLM API server %d on %s",
                engine_client.vllm_config.parallel_config._api_process_rank,
                listen_address,
            )
            shutdown_task = await serve_http(
                app,
                sock=sock,
                enable_ssl_refresh=args.enable_ssl_refresh,
                host=args.host,
                port=args.port,
                log_level=args.uvicorn_log_level,
                # NOTE: When the 'disable_uvicorn_access_log' value is True,
                # no access log will be output.
                access_log=not args.disable_uvicorn_access_log,
                timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile,
                ssl_ca_certs=args.ssl_ca_certs,
                ssl_cert_reqs=args.ssl_cert_reqs,
                ssl_ciphers=args.ssl_ciphers,
                h11_max_incomplete_event_size=args.h11_max_incomplete_event_size,
                h11_max_header_count=args.h11_max_header_count,
                **uvicorn_kwargs,
            )

        # NB: Await server shutdown only after the backend context is exited
        await shutdown_task
    finally:
        # Close the socket on every exit path. Previously a failure during
        # engine creation, build_app, init_app_state or serve_http leaked it.
        sock.close()
if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/entrypoints/cli/main.py for CLI
    # entrypoints.
    cli_env_setup()

    # Build the serve parser on top of a fresh FlexibleArgumentParser, parse
    # and validate the CLI, then run the single-worker server on uvloop.
    args = make_arg_parser(
        FlexibleArgumentParser(description="vLLM OpenAI-Compatible RESTful API server.")
    ).parse_args()
    validate_parsed_serve_args(args)
    uvloop.run(run_server(args))