[Frontend] Cleanup api server (#33158)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io> Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
@@ -1,34 +1,23 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import hashlib
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import multiprocessing
|
||||
import multiprocessing.forkserver as forkserver
|
||||
import os
|
||||
import secrets
|
||||
import signal
|
||||
import socket
|
||||
import tempfile
|
||||
import uuid
|
||||
from argparse import Namespace
|
||||
from collections.abc import AsyncIterator, Awaitable
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
|
||||
import model_hosting_container_standards.sagemaker as sagemaker_standards
|
||||
import pydantic
|
||||
import uvloop
|
||||
from fastapi import APIRouter, FastAPI, HTTPException, Request
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.concurrency import iterate_in_threadpool
|
||||
from starlette.datastructures import URL, Headers, MutableHeaders, State
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
from starlette.datastructures import State
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
@@ -37,15 +26,16 @@ from vllm.entrypoints.chat_utils import load_chat_template
|
||||
from vllm.entrypoints.launcher import serve_http
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.serving import OpenAIServing
|
||||
from vllm.entrypoints.openai.models.protocol import BaseModelPath
|
||||
from vllm.entrypoints.openai.models.serving import (
|
||||
OpenAIServingModels,
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.server_utils import (
|
||||
get_uvicorn_log_config,
|
||||
http_exception_handler,
|
||||
lifespan,
|
||||
log_response,
|
||||
validation_exception_handler,
|
||||
)
|
||||
from vllm.entrypoints.sagemaker.api_router import sagemaker_standards_bootstrap
|
||||
from vllm.entrypoints.serve.elastic_ep.middleware import (
|
||||
ScalingMiddleware,
|
||||
)
|
||||
@@ -55,16 +45,13 @@ from vllm.entrypoints.utils import (
|
||||
log_non_default_args,
|
||||
log_version_and_model,
|
||||
process_lora_modules,
|
||||
sanitize_message,
|
||||
)
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.logger import init_logger
|
||||
from vllm.reasoning import ReasoningParserManager
|
||||
from vllm.tasks import POOLING_TASKS, SupportedTask
|
||||
from vllm.tool_parsers import ToolParserManager
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.utils.gc_utils import freeze_gc_heap
|
||||
from vllm.utils.network_utils import is_valid_ipv6_address
|
||||
from vllm.utils.system_utils import decorate_logs, set_ulimit
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
@@ -75,39 +62,6 @@ prometheus_multiproc_dir: tempfile.TemporaryDirectory
|
||||
logger = init_logger("vllm.entrypoints.openai.api_server")
|
||||
|
||||
|
||||
_running_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
try:
|
||||
if app.state.log_stats:
|
||||
engine_client: EngineClient = app.state.engine_client
|
||||
|
||||
async def _force_log():
|
||||
while True:
|
||||
await asyncio.sleep(envs.VLLM_LOG_STATS_INTERVAL)
|
||||
await engine_client.do_log_stats()
|
||||
|
||||
task = asyncio.create_task(_force_log())
|
||||
_running_tasks.add(task)
|
||||
task.add_done_callback(_running_tasks.remove)
|
||||
else:
|
||||
task = None
|
||||
|
||||
# Mark the startup heap as static so that it's ignored by GC.
|
||||
# Reduces pause times of oldest generation collections.
|
||||
freeze_gc_heap()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if task is not None:
|
||||
task.cancel()
|
||||
finally:
|
||||
# Ensure app state including engine ref is gc'd
|
||||
del app.state
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def build_async_engine_client(
|
||||
args: Namespace,
|
||||
@@ -197,313 +151,6 @@ async def build_async_engine_client_from_engine_args(
|
||||
async_llm.shutdown()
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def base(request: Request) -> OpenAIServing:
|
||||
# Reuse the existing instance
|
||||
return tokenization(request)
|
||||
|
||||
|
||||
def tokenization(request: Request) -> OpenAIServingTokenization:
|
||||
return request.app.state.openai_serving_tokenization
|
||||
|
||||
|
||||
def engine_client(request: Request) -> EngineClient:
|
||||
return request.app.state.engine_client
|
||||
|
||||
|
||||
@router.get("/load")
|
||||
async def get_server_load_metrics(request: Request):
|
||||
# This endpoint returns the current server load metrics.
|
||||
# It tracks requests utilizing the GPU from the following routes:
|
||||
# - /v1/responses
|
||||
# - /v1/responses/{response_id}
|
||||
# - /v1/responses/{response_id}/cancel
|
||||
# - /v1/messages
|
||||
# - /v1/chat/completions
|
||||
# - /v1/completions
|
||||
# - /v1/audio/transcriptions
|
||||
# - /v1/audio/translations
|
||||
# - /v1/embeddings
|
||||
# - /pooling
|
||||
# - /classify
|
||||
# - /score
|
||||
# - /v1/score
|
||||
# - /rerank
|
||||
# - /v1/rerank
|
||||
# - /v2/rerank
|
||||
return JSONResponse(content={"server_load": request.app.state.server_load_metrics})
|
||||
|
||||
|
||||
@router.get("/version")
|
||||
async def show_version():
|
||||
ver = {"version": VLLM_VERSION}
|
||||
return JSONResponse(content=ver)
|
||||
|
||||
|
||||
def load_log_config(log_config_file: str | None) -> dict | None:
|
||||
if not log_config_file:
|
||||
return None
|
||||
try:
|
||||
with open(log_config_file) as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to load log config from file %s: error %s", log_config_file, e
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def get_uvicorn_log_config(args: Namespace) -> dict | None:
|
||||
"""
|
||||
Get the uvicorn log config based on the provided arguments.
|
||||
|
||||
Priority:
|
||||
1. If log_config_file is specified, use it
|
||||
2. If disable_access_log_for_endpoints is specified, create a config with
|
||||
the access log filter
|
||||
3. Otherwise, return None (use uvicorn defaults)
|
||||
"""
|
||||
# First, try to load from file if specified
|
||||
log_config = load_log_config(args.log_config_file)
|
||||
if log_config is not None:
|
||||
return log_config
|
||||
|
||||
# If endpoints to filter are specified, create a config with the filter
|
||||
if args.disable_access_log_for_endpoints:
|
||||
from vllm.logging_utils import create_uvicorn_log_config
|
||||
|
||||
# Parse comma-separated string into list
|
||||
excluded_paths = [
|
||||
p.strip()
|
||||
for p in args.disable_access_log_for_endpoints.split(",")
|
||||
if p.strip()
|
||||
]
|
||||
return create_uvicorn_log_config(
|
||||
excluded_paths=excluded_paths,
|
||||
log_level=args.uvicorn_log_level,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class AuthenticationMiddleware:
|
||||
"""
|
||||
Pure ASGI middleware that authenticates each request by checking
|
||||
if the Authorization Bearer token exists and equals anyof "{api_key}".
|
||||
|
||||
Notes
|
||||
-----
|
||||
There are two cases in which authentication is skipped:
|
||||
1. The HTTP method is OPTIONS.
|
||||
2. The request path doesn't start with /v1 (e.g. /health).
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
|
||||
self.app = app
|
||||
self.api_tokens = [hashlib.sha256(t.encode("utf-8")).digest() for t in tokens]
|
||||
|
||||
def verify_token(self, headers: Headers) -> bool:
|
||||
authorization_header_value = headers.get("Authorization")
|
||||
if not authorization_header_value:
|
||||
return False
|
||||
|
||||
scheme, _, param = authorization_header_value.partition(" ")
|
||||
if scheme.lower() != "bearer":
|
||||
return False
|
||||
|
||||
param_hash = hashlib.sha256(param.encode("utf-8")).digest()
|
||||
|
||||
token_match = False
|
||||
for token_hash in self.api_tokens:
|
||||
token_match |= secrets.compare_digest(param_hash, token_hash)
|
||||
|
||||
return token_match
|
||||
|
||||
def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
|
||||
if scope["type"] not in ("http", "websocket") or scope["method"] == "OPTIONS":
|
||||
# scope["type"] can be "lifespan" or "startup" for example,
|
||||
# in which case we don't need to do anything
|
||||
return self.app(scope, receive, send)
|
||||
root_path = scope.get("root_path", "")
|
||||
url_path = URL(scope=scope).path.removeprefix(root_path)
|
||||
headers = Headers(scope=scope)
|
||||
# Type narrow to satisfy mypy.
|
||||
if url_path.startswith("/v1") and not self.verify_token(headers):
|
||||
response = JSONResponse(content={"error": "Unauthorized"}, status_code=401)
|
||||
return response(scope, receive, send)
|
||||
return self.app(scope, receive, send)
|
||||
|
||||
|
||||
class XRequestIdMiddleware:
|
||||
"""
|
||||
Middleware the set's the X-Request-Id header for each response
|
||||
to a random uuid4 (hex) value if the header isn't already
|
||||
present in the request, otherwise use the provided request id.
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
self.app = app
|
||||
|
||||
def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
|
||||
if scope["type"] not in ("http", "websocket"):
|
||||
return self.app(scope, receive, send)
|
||||
|
||||
# Extract the request headers.
|
||||
request_headers = Headers(scope=scope)
|
||||
|
||||
async def send_with_request_id(message: Message) -> None:
|
||||
"""
|
||||
Custom send function to mutate the response headers
|
||||
and append X-Request-Id to it.
|
||||
"""
|
||||
if message["type"] == "http.response.start":
|
||||
response_headers = MutableHeaders(raw=message["headers"])
|
||||
request_id = request_headers.get("X-Request-Id", uuid.uuid4().hex)
|
||||
response_headers.append("X-Request-Id", request_id)
|
||||
await send(message)
|
||||
|
||||
return self.app(scope, receive, send_with_request_id)
|
||||
|
||||
|
||||
def _extract_content_from_chunk(chunk_data: dict) -> str:
|
||||
"""Extract content from a streaming response chunk."""
|
||||
try:
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionStreamResponse,
|
||||
)
|
||||
from vllm.entrypoints.openai.completion.protocol import (
|
||||
CompletionStreamResponse,
|
||||
)
|
||||
|
||||
# Try using Completion types for type-safe parsing
|
||||
if chunk_data.get("object") == "chat.completion.chunk":
|
||||
chat_response = ChatCompletionStreamResponse.model_validate(chunk_data)
|
||||
if chat_response.choices and chat_response.choices[0].delta.content:
|
||||
return chat_response.choices[0].delta.content
|
||||
elif chunk_data.get("object") == "text_completion":
|
||||
completion_response = CompletionStreamResponse.model_validate(chunk_data)
|
||||
if completion_response.choices and completion_response.choices[0].text:
|
||||
return completion_response.choices[0].text
|
||||
except pydantic.ValidationError:
|
||||
# Fallback to manual parsing
|
||||
if "choices" in chunk_data and chunk_data["choices"]:
|
||||
choice = chunk_data["choices"][0]
|
||||
if "delta" in choice and choice["delta"].get("content"):
|
||||
return choice["delta"]["content"]
|
||||
elif choice.get("text"):
|
||||
return choice["text"]
|
||||
return ""
|
||||
|
||||
|
||||
class SSEDecoder:
|
||||
"""Robust Server-Sent Events decoder for streaming responses."""
|
||||
|
||||
def __init__(self):
|
||||
self.buffer = ""
|
||||
self.content_buffer = []
|
||||
|
||||
def decode_chunk(self, chunk: bytes) -> list[dict]:
|
||||
"""Decode a chunk of SSE data and return parsed events."""
|
||||
import json
|
||||
|
||||
try:
|
||||
chunk_str = chunk.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
# Skip malformed chunks
|
||||
return []
|
||||
|
||||
self.buffer += chunk_str
|
||||
events = []
|
||||
|
||||
# Process complete lines
|
||||
while "\n" in self.buffer:
|
||||
line, self.buffer = self.buffer.split("\n", 1)
|
||||
line = line.rstrip("\r") # Handle CRLF
|
||||
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:].strip()
|
||||
if data_str == "[DONE]":
|
||||
events.append({"type": "done"})
|
||||
elif data_str:
|
||||
try:
|
||||
event_data = json.loads(data_str)
|
||||
events.append({"type": "data", "data": event_data})
|
||||
except json.JSONDecodeError:
|
||||
# Skip malformed JSON
|
||||
continue
|
||||
|
||||
return events
|
||||
|
||||
def extract_content(self, event_data: dict) -> str:
|
||||
"""Extract content from event data."""
|
||||
return _extract_content_from_chunk(event_data)
|
||||
|
||||
def add_content(self, content: str) -> None:
|
||||
"""Add content to the buffer."""
|
||||
if content:
|
||||
self.content_buffer.append(content)
|
||||
|
||||
def get_complete_content(self) -> str:
|
||||
"""Get the complete buffered content."""
|
||||
return "".join(self.content_buffer)
|
||||
|
||||
|
||||
def _log_streaming_response(response, response_body: list) -> None:
|
||||
"""Log streaming response with robust SSE parsing."""
|
||||
from starlette.concurrency import iterate_in_threadpool
|
||||
|
||||
sse_decoder = SSEDecoder()
|
||||
chunk_count = 0
|
||||
|
||||
def buffered_iterator():
|
||||
nonlocal chunk_count
|
||||
|
||||
for chunk in response_body:
|
||||
chunk_count += 1
|
||||
yield chunk
|
||||
|
||||
# Parse SSE events from chunk
|
||||
events = sse_decoder.decode_chunk(chunk)
|
||||
|
||||
for event in events:
|
||||
if event["type"] == "data":
|
||||
content = sse_decoder.extract_content(event["data"])
|
||||
sse_decoder.add_content(content)
|
||||
elif event["type"] == "done":
|
||||
# Log complete content when done
|
||||
full_content = sse_decoder.get_complete_content()
|
||||
if full_content:
|
||||
# Truncate if too long
|
||||
if len(full_content) > 2048:
|
||||
full_content = full_content[:2048] + ""
|
||||
"...[truncated]"
|
||||
logger.info(
|
||||
"response_body={streaming_complete: content=%r, chunks=%d}",
|
||||
full_content,
|
||||
chunk_count,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"response_body={streaming_complete: no_content, chunks=%d}",
|
||||
chunk_count,
|
||||
)
|
||||
return
|
||||
|
||||
response.body_iterator = iterate_in_threadpool(buffered_iterator())
|
||||
logger.info("response_body={streaming_started: chunks=%d}", len(response_body))
|
||||
|
||||
|
||||
def _log_non_streaming_response(response_body: list) -> None:
|
||||
"""Log non-streaming response."""
|
||||
try:
|
||||
decoded_body = response_body[0].decode()
|
||||
logger.info("response_body={%s}", decoded_body)
|
||||
except UnicodeDecodeError:
|
||||
logger.info("response_body={<binary_data>}")
|
||||
|
||||
|
||||
def build_app(args: Namespace, supported_tasks: tuple["SupportedTask", ...]) -> FastAPI:
|
||||
if args.disable_fastapi_docs:
|
||||
app = FastAPI(
|
||||
@@ -514,7 +161,10 @@ def build_app(args: Namespace, supported_tasks: tuple["SupportedTask", ...]) ->
|
||||
else:
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.state.args = args
|
||||
app.include_router(router)
|
||||
|
||||
from vllm.entrypoints.openai.basic.api_router import register_basic_api_routers
|
||||
|
||||
register_basic_api_routers(app)
|
||||
|
||||
from vllm.entrypoints.serve import register_vllm_serve_api_routers
|
||||
|
||||
@@ -560,51 +210,18 @@ def build_app(args: Namespace, supported_tasks: tuple["SupportedTask", ...]) ->
|
||||
allow_headers=args.allowed_headers,
|
||||
)
|
||||
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_exception_handler(_: Request, exc: HTTPException):
|
||||
err = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message=sanitize_message(exc.detail),
|
||||
type=HTTPStatus(exc.status_code).phrase,
|
||||
code=exc.status_code,
|
||||
)
|
||||
)
|
||||
return JSONResponse(err.model_dump(), status_code=exc.status_code)
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(_: Request, exc: RequestValidationError):
|
||||
param = None
|
||||
errors = exc.errors()
|
||||
for error in errors:
|
||||
if "ctx" in error and "error" in error["ctx"]:
|
||||
ctx_error = error["ctx"]["error"]
|
||||
if isinstance(ctx_error, VLLMValidationError):
|
||||
param = ctx_error.parameter
|
||||
break
|
||||
|
||||
exc_str = str(exc)
|
||||
errors_str = str(errors)
|
||||
|
||||
if errors and errors_str and errors_str != exc_str:
|
||||
message = f"{exc_str} {errors_str}"
|
||||
else:
|
||||
message = exc_str
|
||||
|
||||
err = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message=sanitize_message(message),
|
||||
type=HTTPStatus.BAD_REQUEST.phrase,
|
||||
code=HTTPStatus.BAD_REQUEST,
|
||||
param=param,
|
||||
)
|
||||
)
|
||||
return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
|
||||
app.exception_handler(HTTPException)(http_exception_handler)
|
||||
app.exception_handler(RequestValidationError)(validation_exception_handler)
|
||||
|
||||
# Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
|
||||
if tokens := [key for key in (args.api_key or [envs.VLLM_API_KEY]) if key]:
|
||||
from vllm.entrypoints.openai.server_utils import AuthenticationMiddleware
|
||||
|
||||
app.add_middleware(AuthenticationMiddleware, tokens=tokens)
|
||||
|
||||
if args.enable_request_id_headers:
|
||||
from vllm.entrypoints.openai.server_utils import XRequestIdMiddleware
|
||||
|
||||
app.add_middleware(XRequestIdMiddleware)
|
||||
|
||||
# Add scaling middleware to check for scaling state
|
||||
@@ -616,24 +233,7 @@ def build_app(args: Namespace, supported_tasks: tuple["SupportedTask", ...]) ->
|
||||
"This can include sensitive information and should be "
|
||||
"avoided in production."
|
||||
)
|
||||
|
||||
@app.middleware("http")
|
||||
async def log_response(request: Request, call_next):
|
||||
response = await call_next(request)
|
||||
response_body = [section async for section in response.body_iterator]
|
||||
response.body_iterator = iterate_in_threadpool(iter(response_body))
|
||||
# Check if this is a streaming response by looking at content-type
|
||||
content_type = response.headers.get("content-type", "")
|
||||
is_streaming = content_type == "text/event-stream; charset=utf-8"
|
||||
|
||||
# Log response body based on type
|
||||
if not response_body:
|
||||
logger.info("response_body={<empty>}")
|
||||
elif is_streaming:
|
||||
_log_streaming_response(response, response_body)
|
||||
else:
|
||||
_log_non_streaming_response(response_body)
|
||||
return response
|
||||
app.middleware("http")(log_response)
|
||||
|
||||
for middleware in args.middleware:
|
||||
module_path, object_name = middleware.rsplit(".", 1)
|
||||
@@ -647,8 +247,7 @@ def build_app(args: Namespace, supported_tasks: tuple["SupportedTask", ...]) ->
|
||||
f"Invalid middleware {middleware}. Must be a function or a class."
|
||||
)
|
||||
|
||||
app = sagemaker_standards.bootstrap(app)
|
||||
|
||||
app = sagemaker_standards_bootstrap(app)
|
||||
return app
|
||||
|
||||
|
||||
|
||||
0
vllm/entrypoints/openai/basic/__init__.py
Normal file
0
vllm/entrypoints/openai/basic/__init__.py
Normal file
61
vllm/entrypoints/openai/basic/api_router.py
Normal file
61
vllm/entrypoints/openai/basic/api_router.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from fastapi import APIRouter, FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.openai.engine.serving import OpenAIServing
|
||||
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
|
||||
from vllm.logger import init_logger
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def base(request: Request) -> OpenAIServing:
|
||||
# Reuse the existing instance
|
||||
return tokenization(request)
|
||||
|
||||
|
||||
def tokenization(request: Request) -> OpenAIServingTokenization:
|
||||
return request.app.state.openai_serving_tokenization
|
||||
|
||||
|
||||
def engine_client(request: Request) -> EngineClient:
|
||||
return request.app.state.engine_client
|
||||
|
||||
|
||||
@router.get("/load")
|
||||
async def get_server_load_metrics(request: Request):
|
||||
# This endpoint returns the current server load metrics.
|
||||
# It tracks requests utilizing the GPU from the following routes:
|
||||
# - /v1/responses
|
||||
# - /v1/responses/{response_id}
|
||||
# - /v1/responses/{response_id}/cancel
|
||||
# - /v1/messages
|
||||
# - /v1/chat/completions
|
||||
# - /v1/completions
|
||||
# - /v1/audio/transcriptions
|
||||
# - /v1/audio/translations
|
||||
# - /v1/embeddings
|
||||
# - /pooling
|
||||
# - /classify
|
||||
# - /score
|
||||
# - /v1/score
|
||||
# - /rerank
|
||||
# - /v1/rerank
|
||||
# - /v2/rerank
|
||||
return JSONResponse(content={"server_load": request.app.state.server_load_metrics})
|
||||
|
||||
|
||||
@router.get("/version")
|
||||
async def show_version():
|
||||
ver = {"version": VLLM_VERSION}
|
||||
return JSONResponse(content=ver)
|
||||
|
||||
|
||||
def register_basic_api_routers(app: FastAPI):
|
||||
app.include_router(router)
|
||||
382
vllm/entrypoints/openai/server_utils.py
Normal file
382
vllm/entrypoints/openai/server_utils.py
Normal file
@@ -0,0 +1,382 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import secrets
|
||||
import uuid
|
||||
from argparse import Namespace
|
||||
from collections.abc import Awaitable
|
||||
from contextlib import asynccontextmanager
|
||||
from http import HTTPStatus
|
||||
|
||||
import pydantic
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.concurrency import iterate_in_threadpool
|
||||
from starlette.datastructures import URL, Headers, MutableHeaders
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
from vllm import envs
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorInfo, ErrorResponse
|
||||
from vllm.entrypoints.utils import sanitize_message
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.gc_utils import freeze_gc_heap
|
||||
|
||||
logger = init_logger("vllm.entrypoints.openai.server_utils")
|
||||
|
||||
|
||||
class AuthenticationMiddleware:
|
||||
"""
|
||||
Pure ASGI middleware that authenticates each request by checking
|
||||
if the Authorization Bearer token exists and equals anyof "{api_key}".
|
||||
|
||||
Notes
|
||||
-----
|
||||
There are two cases in which authentication is skipped:
|
||||
1. The HTTP method is OPTIONS.
|
||||
2. The request path doesn't start with /v1 (e.g. /health).
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
|
||||
self.app = app
|
||||
self.api_tokens = [hashlib.sha256(t.encode("utf-8")).digest() for t in tokens]
|
||||
|
||||
def verify_token(self, headers: Headers) -> bool:
|
||||
authorization_header_value = headers.get("Authorization")
|
||||
if not authorization_header_value:
|
||||
return False
|
||||
|
||||
scheme, _, param = authorization_header_value.partition(" ")
|
||||
if scheme.lower() != "bearer":
|
||||
return False
|
||||
|
||||
param_hash = hashlib.sha256(param.encode("utf-8")).digest()
|
||||
|
||||
token_match = False
|
||||
for token_hash in self.api_tokens:
|
||||
token_match |= secrets.compare_digest(param_hash, token_hash)
|
||||
|
||||
return token_match
|
||||
|
||||
def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
|
||||
if scope["type"] not in ("http", "websocket") or scope["method"] == "OPTIONS":
|
||||
# scope["type"] can be "lifespan" or "startup" for example,
|
||||
# in which case we don't need to do anything
|
||||
return self.app(scope, receive, send)
|
||||
root_path = scope.get("root_path", "")
|
||||
url_path = URL(scope=scope).path.removeprefix(root_path)
|
||||
headers = Headers(scope=scope)
|
||||
# Type narrow to satisfy mypy.
|
||||
if url_path.startswith("/v1") and not self.verify_token(headers):
|
||||
response = JSONResponse(content={"error": "Unauthorized"}, status_code=401)
|
||||
return response(scope, receive, send)
|
||||
return self.app(scope, receive, send)
|
||||
|
||||
|
||||
class XRequestIdMiddleware:
|
||||
"""
|
||||
Middleware the set's the X-Request-Id header for each response
|
||||
to a random uuid4 (hex) value if the header isn't already
|
||||
present in the request, otherwise use the provided request id.
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
self.app = app
|
||||
|
||||
def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
|
||||
if scope["type"] not in ("http", "websocket"):
|
||||
return self.app(scope, receive, send)
|
||||
|
||||
# Extract the request headers.
|
||||
request_headers = Headers(scope=scope)
|
||||
|
||||
async def send_with_request_id(message: Message) -> None:
|
||||
"""
|
||||
Custom send function to mutate the response headers
|
||||
and append X-Request-Id to it.
|
||||
"""
|
||||
if message["type"] == "http.response.start":
|
||||
response_headers = MutableHeaders(raw=message["headers"])
|
||||
request_id = request_headers.get("X-Request-Id", uuid.uuid4().hex)
|
||||
response_headers.append("X-Request-Id", request_id)
|
||||
await send(message)
|
||||
|
||||
return self.app(scope, receive, send_with_request_id)
|
||||
|
||||
|
||||
def load_log_config(log_config_file: str | None) -> dict | None:
|
||||
if not log_config_file:
|
||||
return None
|
||||
try:
|
||||
with open(log_config_file) as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to load log config from file %s: error %s", log_config_file, e
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def get_uvicorn_log_config(args: Namespace) -> dict | None:
|
||||
"""
|
||||
Get the uvicorn log config based on the provided arguments.
|
||||
|
||||
Priority:
|
||||
1. If log_config_file is specified, use it
|
||||
2. If disable_access_log_for_endpoints is specified, create a config with
|
||||
the access log filter
|
||||
3. Otherwise, return None (use uvicorn defaults)
|
||||
"""
|
||||
# First, try to load from file if specified
|
||||
log_config = load_log_config(args.log_config_file)
|
||||
if log_config is not None:
|
||||
return log_config
|
||||
|
||||
# If endpoints to filter are specified, create a config with the filter
|
||||
if args.disable_access_log_for_endpoints:
|
||||
from vllm.logging_utils import create_uvicorn_log_config
|
||||
|
||||
# Parse comma-separated string into list
|
||||
excluded_paths = [
|
||||
p.strip()
|
||||
for p in args.disable_access_log_for_endpoints.split(",")
|
||||
if p.strip()
|
||||
]
|
||||
return create_uvicorn_log_config(
|
||||
excluded_paths=excluded_paths,
|
||||
log_level=args.uvicorn_log_level,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _extract_content_from_chunk(chunk_data: dict) -> str:
|
||||
"""Extract content from a streaming response chunk."""
|
||||
try:
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionStreamResponse,
|
||||
)
|
||||
from vllm.entrypoints.openai.completion.protocol import (
|
||||
CompletionStreamResponse,
|
||||
)
|
||||
|
||||
# Try using Completion types for type-safe parsing
|
||||
if chunk_data.get("object") == "chat.completion.chunk":
|
||||
chat_response = ChatCompletionStreamResponse.model_validate(chunk_data)
|
||||
if chat_response.choices and chat_response.choices[0].delta.content:
|
||||
return chat_response.choices[0].delta.content
|
||||
elif chunk_data.get("object") == "text_completion":
|
||||
completion_response = CompletionStreamResponse.model_validate(chunk_data)
|
||||
if completion_response.choices and completion_response.choices[0].text:
|
||||
return completion_response.choices[0].text
|
||||
except pydantic.ValidationError:
|
||||
# Fallback to manual parsing
|
||||
if "choices" in chunk_data and chunk_data["choices"]:
|
||||
choice = chunk_data["choices"][0]
|
||||
if "delta" in choice and choice["delta"].get("content"):
|
||||
return choice["delta"]["content"]
|
||||
elif choice.get("text"):
|
||||
return choice["text"]
|
||||
return ""
|
||||
|
||||
|
||||
class SSEDecoder:
|
||||
"""Robust Server-Sent Events decoder for streaming responses."""
|
||||
|
||||
def __init__(self):
|
||||
self.buffer = ""
|
||||
self.content_buffer = []
|
||||
|
||||
def decode_chunk(self, chunk: bytes) -> list[dict]:
|
||||
"""Decode a chunk of SSE data and return parsed events."""
|
||||
import json
|
||||
|
||||
try:
|
||||
chunk_str = chunk.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
# Skip malformed chunks
|
||||
return []
|
||||
|
||||
self.buffer += chunk_str
|
||||
events = []
|
||||
|
||||
# Process complete lines
|
||||
while "\n" in self.buffer:
|
||||
line, self.buffer = self.buffer.split("\n", 1)
|
||||
line = line.rstrip("\r") # Handle CRLF
|
||||
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:].strip()
|
||||
if data_str == "[DONE]":
|
||||
events.append({"type": "done"})
|
||||
elif data_str:
|
||||
try:
|
||||
event_data = json.loads(data_str)
|
||||
events.append({"type": "data", "data": event_data})
|
||||
except json.JSONDecodeError:
|
||||
# Skip malformed JSON
|
||||
continue
|
||||
|
||||
return events
|
||||
|
||||
def extract_content(self, event_data: dict) -> str:
|
||||
"""Extract content from event data."""
|
||||
return _extract_content_from_chunk(event_data)
|
||||
|
||||
def add_content(self, content: str) -> None:
|
||||
"""Add content to the buffer."""
|
||||
if content:
|
||||
self.content_buffer.append(content)
|
||||
|
||||
def get_complete_content(self) -> str:
|
||||
"""Get the complete buffered content."""
|
||||
return "".join(self.content_buffer)
|
||||
|
||||
|
||||
def _log_streaming_response(response, response_body: list) -> None:
    """Log a streaming (SSE) response body with robust SSE parsing.

    Wraps the buffered body in an iterator that re-yields each chunk to
    the client while incrementally decoding SSE events; once the
    ``[DONE]`` sentinel is seen, the accumulated content is logged.
    """
    from starlette.concurrency import iterate_in_threadpool

    sse_decoder = SSEDecoder()
    chunk_count = 0

    def buffered_iterator():
        nonlocal chunk_count

        for chunk in response_body:
            chunk_count += 1
            yield chunk

            # Parse SSE events from chunk
            events = sse_decoder.decode_chunk(chunk)

            for event in events:
                if event["type"] == "data":
                    content = sse_decoder.extract_content(event["data"])
                    sse_decoder.add_content(content)
                elif event["type"] == "done":
                    # Log complete content when done
                    full_content = sse_decoder.get_complete_content()
                    if full_content:
                        # Truncate if too long. BUG FIX: the marker was
                        # previously a bare string statement and was never
                        # appended to the truncated content.
                        if len(full_content) > 2048:
                            full_content = full_content[:2048] + "...[truncated]"
                        logger.info(
                            "response_body={streaming_complete: content=%r, chunks=%d}",
                            full_content,
                            chunk_count,
                        )
                    else:
                        logger.info(
                            "response_body={streaming_complete: no_content, chunks=%d}",
                            chunk_count,
                        )
                    return

    response.body_iterator = iterate_in_threadpool(buffered_iterator())
    logger.info("response_body={streaming_started: chunks=%d}", len(response_body))
|
||||
|
||||
|
||||
def _log_non_streaming_response(response_body: list) -> None:
    """Log the body of a fully-buffered (non-streaming) response.

    Only the first body section is logged; callers guarantee the body is
    non-empty before invoking this helper.
    """
    try:
        decoded = response_body[0].decode()
    except UnicodeDecodeError:
        logger.info("response_body={<binary_data>}")
    else:
        logger.info("response_body={%s}", decoded)
|
||||
|
||||
|
||||
async def log_response(request: Request, call_next):
    """Middleware that buffers the response body and logs it.

    The body iterator is consumed, logged (streaming vs. non-streaming is
    decided by content-type), and then restored so the client still
    receives the full response.
    """
    response = await call_next(request)
    body_sections = [section async for section in response.body_iterator]
    response.body_iterator = iterate_in_threadpool(iter(body_sections))

    # Streaming responses are identified by their SSE content-type header.
    is_streaming = (
        response.headers.get("content-type", "") == "text/event-stream; charset=utf-8"
    )

    # Log response body based on type
    if not body_sections:
        logger.info("response_body={<empty>}")
    elif is_streaming:
        _log_streaming_response(response, body_sections)
    else:
        _log_non_streaming_response(body_sections)
    return response
|
||||
|
||||
|
||||
async def http_exception_handler(_: Request, exc: HTTPException):
    """Render an ``HTTPException`` as an OpenAI-style ``ErrorResponse``."""
    status = HTTPStatus(exc.status_code)
    error_info = ErrorInfo(
        message=sanitize_message(exc.detail),
        type=status.phrase,
        code=exc.status_code,
    )
    return JSONResponse(
        ErrorResponse(error=error_info).model_dump(),
        status_code=exc.status_code,
    )
|
||||
|
||||
|
||||
async def validation_exception_handler(_: Request, exc: RequestValidationError):
    """Render a pydantic/FastAPI validation failure as an ``ErrorResponse``.

    When the underlying cause is a ``VLLMValidationError``, the offending
    parameter name is surfaced in the ``param`` field.
    """
    errors = exc.errors()

    # Pull the parameter name from the first VLLMValidationError, if any.
    param = None
    for error in errors:
        ctx_error = error.get("ctx", {}).get("error")
        if isinstance(ctx_error, VLLMValidationError):
            param = ctx_error.parameter
            break

    exc_str = str(exc)
    errors_str = str(errors)

    # Include the structured error list only when it adds information.
    if errors and errors_str and errors_str != exc_str:
        message = f"{exc_str} {errors_str}"
    else:
        message = exc_str

    error_info = ErrorInfo(
        message=sanitize_message(message),
        type=HTTPStatus.BAD_REQUEST.phrase,
        code=HTTPStatus.BAD_REQUEST,
        param=param,
    )
    return JSONResponse(
        ErrorResponse(error=error_info).model_dump(),
        status_code=HTTPStatus.BAD_REQUEST,
    )
|
||||
|
||||
|
||||
# Strong references to background tasks: asyncio only keeps weak refs to
# running tasks, so each created task is added here and removed via its
# done-callback to prevent premature garbage collection.
_running_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan context.

    On startup, optionally launches a periodic stats-logging task and
    freezes the startup heap for GC; on shutdown, cancels the task and
    drops ``app.state`` so the engine reference can be collected.
    """
    try:
        stats_task: asyncio.Task | None = None
        if app.state.log_stats:
            engine_client: EngineClient = app.state.engine_client

            async def _force_log():
                # Periodically trigger engine stats logging.
                while True:
                    await asyncio.sleep(envs.VLLM_LOG_STATS_INTERVAL)
                    await engine_client.do_log_stats()

            stats_task = asyncio.create_task(_force_log())
            # Keep a strong reference so the task isn't garbage-collected.
            _running_tasks.add(stats_task)
            stats_task.add_done_callback(_running_tasks.remove)

        # Mark the startup heap as static so that it's ignored by GC.
        # Reduces pause times of oldest generation collections.
        freeze_gc_heap()
        try:
            yield
        finally:
            if stats_task is not None:
                stats_task.cancel()
    finally:
        # Ensure app state including engine ref is gc'd
        del app.state
|
||||
@@ -10,7 +10,7 @@ import pydantic
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
|
||||
from vllm.entrypoints.openai.api_server import base
|
||||
from vllm.entrypoints.openai.basic.api_router import base
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.engine.serving import OpenAIServing
|
||||
from vllm.entrypoints.openai.utils import validate_json_request
|
||||
@@ -158,3 +158,7 @@ def attach_router(app: FastAPI, supported_tasks: tuple["SupportedTask", ...]):
|
||||
return JSONResponse(content=res.model_dump(), status_code=res.error.code)
|
||||
|
||||
app.include_router(router)
|
||||
|
||||
|
||||
def sagemaker_standards_bootstrap(app: FastAPI) -> FastAPI:
    """Wrap the app with SageMaker model-hosting container standards.

    Thin passthrough to ``model_hosting_container_standards.sagemaker``'s
    ``bootstrap`` (imported at module level as ``sagemaker_standards``).
    """
    return sagemaker_standards.bootstrap(app)
|
||||
|
||||
Reference in New Issue
Block a user