[Frontend] Cleanup api server (#33158)

Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
wang.yuqi
2026-01-27 23:18:10 +08:00
committed by GitHub
parent 5ec44056f7
commit 7cbbca9aaa
5 changed files with 471 additions and 425 deletions

View File

@@ -1,34 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import hashlib
import importlib
import inspect
import json
import multiprocessing
import multiprocessing.forkserver as forkserver
import os
import secrets
import signal
import socket
import tempfile
import uuid
from argparse import Namespace
from collections.abc import AsyncIterator, Awaitable
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import Any
import model_hosting_container_standards.sagemaker as sagemaker_standards
import pydantic
import uvloop
from fastapi import APIRouter, FastAPI, HTTPException, Request
from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import URL, Headers, MutableHeaders, State
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.datastructures import State
import vllm.envs as envs
from vllm.engine.arg_utils import AsyncEngineArgs
@@ -37,15 +26,16 @@ from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
from vllm.entrypoints.openai.engine.protocol import (
ErrorInfo,
ErrorResponse,
)
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import (
OpenAIServingModels,
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.server_utils import (
get_uvicorn_log_config,
http_exception_handler,
lifespan,
log_response,
validation_exception_handler,
)
from vllm.entrypoints.sagemaker.api_router import sagemaker_standards_bootstrap
from vllm.entrypoints.serve.elastic_ep.middleware import (
ScalingMiddleware,
)
@@ -55,16 +45,13 @@ from vllm.entrypoints.utils import (
log_non_default_args,
log_version_and_model,
process_lora_modules,
sanitize_message,
)
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.tool_parsers import ToolParserManager
from vllm.usage.usage_lib import UsageContext
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.gc_utils import freeze_gc_heap
from vllm.utils.network_utils import is_valid_ipv6_address
from vllm.utils.system_utils import decorate_logs, set_ulimit
from vllm.version import __version__ as VLLM_VERSION
@@ -75,39 +62,6 @@ prometheus_multiproc_dir: tempfile.TemporaryDirectory
logger = init_logger("vllm.entrypoints.openai.api_server")
_running_tasks: set[asyncio.Task] = set()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan context for the API server.

    On startup: if stats logging is enabled, spawns a background task that
    calls ``engine_client.do_log_stats()`` every ``VLLM_LOG_STATS_INTERVAL``
    seconds, then freezes the GC heap. On shutdown: cancels the logging task
    and drops ``app.state`` so the engine reference can be collected.
    """
    try:
        if app.state.log_stats:
            engine_client: EngineClient = app.state.engine_client

            async def _force_log():
                # Periodically ask the engine to emit its stats.
                while True:
                    await asyncio.sleep(envs.VLLM_LOG_STATS_INTERVAL)
                    await engine_client.do_log_stats()

            task = asyncio.create_task(_force_log())
            # Hold a strong reference so the task isn't garbage collected
            # while running; drop it automatically when the task finishes.
            _running_tasks.add(task)
            task.add_done_callback(_running_tasks.remove)
        else:
            task = None
        # Mark the startup heap as static so that it's ignored by GC.
        # Reduces pause times of oldest generation collections.
        freeze_gc_heap()
        try:
            yield
        finally:
            if task is not None:
                task.cancel()
    finally:
        # Ensure app state including engine ref is gc'd
        del app.state
@asynccontextmanager
async def build_async_engine_client(
args: Namespace,
@@ -197,313 +151,6 @@ async def build_async_engine_client_from_engine_args(
async_llm.shutdown()
router = APIRouter()
def base(request: Request) -> OpenAIServing:
    """Return a generic OpenAIServing handler for this request."""
    # Reuse the existing instance
    return tokenization(request)


def tokenization(request: Request) -> OpenAIServingTokenization:
    """Return the tokenization handler stored on app state at startup."""
    return request.app.state.openai_serving_tokenization


def engine_client(request: Request) -> EngineClient:
    """Return the engine client stored on app state at startup."""
    return request.app.state.engine_client
@router.get("/load")
async def get_server_load_metrics(request: Request):
    """Return the server's current load metric as JSON."""
    # This endpoint returns the current server load metrics.
    # It tracks requests utilizing the GPU from the following routes:
    # - /v1/responses
    # - /v1/responses/{response_id}
    # - /v1/responses/{response_id}/cancel
    # - /v1/messages
    # - /v1/chat/completions
    # - /v1/completions
    # - /v1/audio/transcriptions
    # - /v1/audio/translations
    # - /v1/embeddings
    # - /pooling
    # - /classify
    # - /score
    # - /v1/score
    # - /rerank
    # - /v1/rerank
    # - /v2/rerank
    return JSONResponse(content={"server_load": request.app.state.server_load_metrics})
@router.get("/version")
async def show_version():
    """Return the installed vLLM version as JSON."""
    ver = {"version": VLLM_VERSION}
    return JSONResponse(content=ver)
def load_log_config(log_config_file: str | None) -> dict | None:
if not log_config_file:
return None
try:
with open(log_config_file) as f:
return json.load(f)
except Exception as e:
logger.warning(
"Failed to load log config from file %s: error %s", log_config_file, e
)
return None
def get_uvicorn_log_config(args: Namespace) -> dict | None:
    """
    Get the uvicorn log config based on the provided arguments.

    Priority:
    1. If log_config_file is specified, use it
    2. If disable_access_log_for_endpoints is specified, create a config with
       the access log filter
    3. Otherwise, return None (use uvicorn defaults)
    """
    # A readable config file, when given, wins outright.
    file_config = load_log_config(args.log_config_file)
    if file_config is not None:
        return file_config

    if not args.disable_access_log_for_endpoints:
        return None

    from vllm.logging_utils import create_uvicorn_log_config

    # The CLI value is a comma-separated list of paths; drop empty entries.
    excluded_paths = [
        path.strip()
        for path in args.disable_access_log_for_endpoints.split(",")
        if path.strip()
    ]
    return create_uvicorn_log_config(
        excluded_paths=excluded_paths,
        log_level=args.uvicorn_log_level,
    )
class AuthenticationMiddleware:
    """
    Pure ASGI middleware that authenticates each request by checking
    if the Authorization Bearer token exists and equals any of "{api_key}".

    Notes
    -----
    There are two cases in which authentication is skipped:
        1. The HTTP method is OPTIONS.
        2. The request path doesn't start with /v1 (e.g. /health).
    """

    def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
        self.app = app
        # Store SHA-256 digests instead of the raw keys so comparisons run
        # on fixed-size values and plaintext tokens aren't kept around.
        self.api_tokens = [hashlib.sha256(t.encode("utf-8")).digest() for t in tokens]

    def verify_token(self, headers: Headers) -> bool:
        """Return True iff the Authorization header carries a valid Bearer token."""
        authorization_header_value = headers.get("Authorization")
        if not authorization_header_value:
            return False
        scheme, _, param = authorization_header_value.partition(" ")
        if scheme.lower() != "bearer":
            return False
        param_hash = hashlib.sha256(param.encode("utf-8")).digest()
        # Compare against every configured token without short-circuiting so
        # timing does not leak which token (if any) matched.
        token_match = False
        for token_hash in self.api_tokens:
            token_match |= secrets.compare_digest(param_hash, token_hash)
        return token_match

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        # scope["type"] can be "lifespan" or "startup" for example,
        # in which case we don't need to do anything.
        # Fix: websocket connection scopes carry no "method" key (ASGI spec),
        # so use .get() — indexing raised KeyError for websocket requests.
        if scope["type"] not in ("http", "websocket") or scope.get("method") == "OPTIONS":
            return self.app(scope, receive, send)

        root_path = scope.get("root_path", "")
        url_path = URL(scope=scope).path.removeprefix(root_path)
        headers = Headers(scope=scope)
        # Only OpenAI-compatible endpoints (/v1/...) require authentication.
        if url_path.startswith("/v1") and not self.verify_token(headers):
            response = JSONResponse(content={"error": "Unauthorized"}, status_code=401)
            return response(scope, receive, send)
        return self.app(scope, receive, send)
class XRequestIdMiddleware:
    """
    Middleware that sets the X-Request-Id header on each response,
    using a random uuid4 (hex) value when the header isn't already
    present on the request and echoing the provided id otherwise.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        if scope["type"] not in ("http", "websocket"):
            return self.app(scope, receive, send)

        # Snapshot the incoming request headers.
        incoming_headers = Headers(scope=scope)

        async def send_with_request_id(message: Message) -> None:
            """Inject X-Request-Id into the response-start message."""
            if message["type"] == "http.response.start":
                outgoing_headers = MutableHeaders(raw=message["headers"])
                request_id = incoming_headers.get("X-Request-Id", uuid.uuid4().hex)
                outgoing_headers.append("X-Request-Id", request_id)
            await send(message)

        return self.app(scope, receive, send_with_request_id)
def _extract_content_from_chunk(chunk_data: dict) -> str:
    """Extract content from a streaming response chunk.

    Prefers the typed pydantic models; falls back to manual dict parsing
    when validation fails. Returns "" when no content is present.
    """
    try:
        from vllm.entrypoints.openai.chat_completion.protocol import (
            ChatCompletionStreamResponse,
        )
        from vllm.entrypoints.openai.completion.protocol import (
            CompletionStreamResponse,
        )

        object_kind = chunk_data.get("object")
        # Try using Completion types for type-safe parsing.
        if object_kind == "chat.completion.chunk":
            parsed_chat = ChatCompletionStreamResponse.model_validate(chunk_data)
            if parsed_chat.choices and parsed_chat.choices[0].delta.content:
                return parsed_chat.choices[0].delta.content
        elif object_kind == "text_completion":
            parsed_completion = CompletionStreamResponse.model_validate(chunk_data)
            if parsed_completion.choices and parsed_completion.choices[0].text:
                return parsed_completion.choices[0].text
    except pydantic.ValidationError:
        # Fallback to manual parsing when the typed models reject the chunk.
        choices = chunk_data.get("choices")
        if choices:
            first_choice = choices[0]
            if "delta" in first_choice and first_choice["delta"].get("content"):
                return first_choice["delta"]["content"]
            if first_choice.get("text"):
                return first_choice["text"]
    return ""
class SSEDecoder:
    """Robust Server-Sent Events decoder for streaming responses."""

    def __init__(self):
        self.buffer = ""  # undecoded partial-line text carried between chunks
        self.content_buffer = []  # accumulated content fragments

    def decode_chunk(self, chunk: bytes) -> list[dict]:
        """Decode a chunk of SSE data and return parsed events."""
        import json

        try:
            text = chunk.decode("utf-8")
        except UnicodeDecodeError:
            # Skip malformed chunks
            return []

        self.buffer += text
        events = []
        # Consume every complete line; keep the trailing partial line buffered.
        while "\n" in self.buffer:
            line, _, self.buffer = self.buffer.partition("\n")
            line = line.rstrip("\r")  # Handle CRLF
            if not line.startswith("data: "):
                continue
            payload = line[6:].strip()
            if payload == "[DONE]":
                events.append({"type": "done"})
            elif payload:
                try:
                    events.append({"type": "data", "data": json.loads(payload)})
                except json.JSONDecodeError:
                    # Skip malformed JSON
                    pass
        return events

    def extract_content(self, event_data: dict) -> str:
        """Extract content from event data."""
        return _extract_content_from_chunk(event_data)

    def add_content(self, content: str) -> None:
        """Add content to the buffer."""
        if content:
            self.content_buffer.append(content)

    def get_complete_content(self) -> str:
        """Get the complete buffered content."""
        return "".join(self.content_buffer)
def _log_streaming_response(response, response_body: list) -> None:
    """Log streaming response with robust SSE parsing.

    Replaces the response's body iterator with one that replays the buffered
    chunks to the client while re-parsing the SSE stream, logging the
    assembled content once the stream signals [DONE].
    """
    from starlette.concurrency import iterate_in_threadpool

    sse_decoder = SSEDecoder()
    chunk_count = 0

    def buffered_iterator():
        nonlocal chunk_count

        for chunk in response_body:
            chunk_count += 1
            yield chunk

            # Parse SSE events from chunk
            events = sse_decoder.decode_chunk(chunk)
            for event in events:
                if event["type"] == "data":
                    content = sse_decoder.extract_content(event["data"])
                    sse_decoder.add_content(content)
                elif event["type"] == "done":
                    # Log complete content when done
                    full_content = sse_decoder.get_complete_content()
                    if full_content:
                        # Truncate if too long. Fix: the marker used to be a
                        # stray no-op expression and was never appended.
                        if len(full_content) > 2048:
                            full_content = full_content[:2048] + "...[truncated]"
                        logger.info(
                            "response_body={streaming_complete: content=%r, chunks=%d}",
                            full_content,
                            chunk_count,
                        )
                    else:
                        logger.info(
                            "response_body={streaming_complete: no_content, chunks=%d}",
                            chunk_count,
                        )
                    return

    response.body_iterator = iterate_in_threadpool(buffered_iterator())
    logger.info("response_body={streaming_started: chunks=%d}", len(response_body))
def _log_non_streaming_response(response_body: list) -> None:
    """Log non-streaming response."""
    try:
        body_text = response_body[0].decode()
    except UnicodeDecodeError:
        # Body is not valid UTF-8; don't dump raw bytes into the log.
        logger.info("response_body={<binary_data>}")
    else:
        logger.info("response_body={%s}", body_text)
def build_app(args: Namespace, supported_tasks: tuple["SupportedTask", ...]) -> FastAPI:
if args.disable_fastapi_docs:
app = FastAPI(
@@ -514,7 +161,10 @@ def build_app(args: Namespace, supported_tasks: tuple["SupportedTask", ...]) ->
else:
app = FastAPI(lifespan=lifespan)
app.state.args = args
app.include_router(router)
from vllm.entrypoints.openai.basic.api_router import register_basic_api_routers
register_basic_api_routers(app)
from vllm.entrypoints.serve import register_vllm_serve_api_routers
@@ -560,51 +210,18 @@ def build_app(args: Namespace, supported_tasks: tuple["SupportedTask", ...]) ->
allow_headers=args.allowed_headers,
)
@app.exception_handler(HTTPException)
async def http_exception_handler(_: Request, exc: HTTPException):
err = ErrorResponse(
error=ErrorInfo(
message=sanitize_message(exc.detail),
type=HTTPStatus(exc.status_code).phrase,
code=exc.status_code,
)
)
return JSONResponse(err.model_dump(), status_code=exc.status_code)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_: Request, exc: RequestValidationError):
param = None
errors = exc.errors()
for error in errors:
if "ctx" in error and "error" in error["ctx"]:
ctx_error = error["ctx"]["error"]
if isinstance(ctx_error, VLLMValidationError):
param = ctx_error.parameter
break
exc_str = str(exc)
errors_str = str(errors)
if errors and errors_str and errors_str != exc_str:
message = f"{exc_str} {errors_str}"
else:
message = exc_str
err = ErrorResponse(
error=ErrorInfo(
message=sanitize_message(message),
type=HTTPStatus.BAD_REQUEST.phrase,
code=HTTPStatus.BAD_REQUEST,
param=param,
)
)
return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
app.exception_handler(HTTPException)(http_exception_handler)
app.exception_handler(RequestValidationError)(validation_exception_handler)
# Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
if tokens := [key for key in (args.api_key or [envs.VLLM_API_KEY]) if key]:
from vllm.entrypoints.openai.server_utils import AuthenticationMiddleware
app.add_middleware(AuthenticationMiddleware, tokens=tokens)
if args.enable_request_id_headers:
from vllm.entrypoints.openai.server_utils import XRequestIdMiddleware
app.add_middleware(XRequestIdMiddleware)
# Add scaling middleware to check for scaling state
@@ -616,24 +233,7 @@ def build_app(args: Namespace, supported_tasks: tuple["SupportedTask", ...]) ->
"This can include sensitive information and should be "
"avoided in production."
)
@app.middleware("http")
async def log_response(request: Request, call_next):
response = await call_next(request)
response_body = [section async for section in response.body_iterator]
response.body_iterator = iterate_in_threadpool(iter(response_body))
# Check if this is a streaming response by looking at content-type
content_type = response.headers.get("content-type", "")
is_streaming = content_type == "text/event-stream; charset=utf-8"
# Log response body based on type
if not response_body:
logger.info("response_body={<empty>}")
elif is_streaming:
_log_streaming_response(response, response_body)
else:
_log_non_streaming_response(response_body)
return response
app.middleware("http")(log_response)
for middleware in args.middleware:
module_path, object_name = middleware.rsplit(".", 1)
@@ -647,8 +247,7 @@ def build_app(args: Namespace, supported_tasks: tuple["SupportedTask", ...]) ->
f"Invalid middleware {middleware}. Must be a function or a class."
)
app = sagemaker_standards.bootstrap(app)
app = sagemaker_standards_bootstrap(app)
return app

View File

@@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from fastapi import APIRouter, FastAPI, Request
from fastapi.responses import JSONResponse
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
from vllm.logger import init_logger
from vllm.version import __version__ as VLLM_VERSION
router = APIRouter()
logger = init_logger(__name__)
def base(request: Request) -> OpenAIServing:
    """Return a generic OpenAIServing handler for this request."""
    # Reuse the existing instance
    return tokenization(request)


def tokenization(request: Request) -> OpenAIServingTokenization:
    """Return the tokenization handler stored on app state at startup."""
    return request.app.state.openai_serving_tokenization


def engine_client(request: Request) -> EngineClient:
    """Return the engine client stored on app state at startup."""
    return request.app.state.engine_client
@router.get("/load")
async def get_server_load_metrics(request: Request):
    """Return the server's current load metric as JSON."""
    # This endpoint returns the current server load metrics.
    # It tracks requests utilizing the GPU from the following routes:
    # - /v1/responses
    # - /v1/responses/{response_id}
    # - /v1/responses/{response_id}/cancel
    # - /v1/messages
    # - /v1/chat/completions
    # - /v1/completions
    # - /v1/audio/transcriptions
    # - /v1/audio/translations
    # - /v1/embeddings
    # - /pooling
    # - /classify
    # - /score
    # - /v1/score
    # - /rerank
    # - /v1/rerank
    # - /v2/rerank
    return JSONResponse(content={"server_load": request.app.state.server_load_metrics})
@router.get("/version")
async def show_version():
    """Return the installed vLLM version as JSON."""
    ver = {"version": VLLM_VERSION}
    return JSONResponse(content=ver)
def register_basic_api_routers(app: FastAPI):
    """Attach the basic informational endpoints (/load, /version) to *app*."""
    app.include_router(router)

View File

@@ -0,0 +1,382 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import hashlib
import json
import secrets
import uuid
from argparse import Namespace
from collections.abc import Awaitable
from contextlib import asynccontextmanager
from http import HTTPStatus
import pydantic
from fastapi import FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import URL, Headers, MutableHeaders
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from vllm import envs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.engine.protocol import ErrorInfo, ErrorResponse
from vllm.entrypoints.utils import sanitize_message
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.utils.gc_utils import freeze_gc_heap
logger = init_logger("vllm.entrypoints.openai.server_utils")
class AuthenticationMiddleware:
    """
    Pure ASGI middleware that authenticates each request by checking
    if the Authorization Bearer token exists and equals any of "{api_key}".

    Notes
    -----
    There are two cases in which authentication is skipped:
        1. The HTTP method is OPTIONS.
        2. The request path doesn't start with /v1 (e.g. /health).
    """

    def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
        self.app = app
        # Store SHA-256 digests instead of the raw keys so comparisons run
        # on fixed-size values and plaintext tokens aren't kept around.
        self.api_tokens = [hashlib.sha256(t.encode("utf-8")).digest() for t in tokens]

    def verify_token(self, headers: Headers) -> bool:
        """Return True iff the Authorization header carries a valid Bearer token."""
        authorization_header_value = headers.get("Authorization")
        if not authorization_header_value:
            return False
        scheme, _, param = authorization_header_value.partition(" ")
        if scheme.lower() != "bearer":
            return False
        param_hash = hashlib.sha256(param.encode("utf-8")).digest()
        # Compare against every configured token without short-circuiting so
        # timing does not leak which token (if any) matched.
        token_match = False
        for token_hash in self.api_tokens:
            token_match |= secrets.compare_digest(param_hash, token_hash)
        return token_match

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        # scope["type"] can be "lifespan" or "startup" for example,
        # in which case we don't need to do anything.
        # Fix: websocket connection scopes carry no "method" key (ASGI spec),
        # so use .get() — indexing raised KeyError for websocket requests.
        if scope["type"] not in ("http", "websocket") or scope.get("method") == "OPTIONS":
            return self.app(scope, receive, send)

        root_path = scope.get("root_path", "")
        url_path = URL(scope=scope).path.removeprefix(root_path)
        headers = Headers(scope=scope)
        # Only OpenAI-compatible endpoints (/v1/...) require authentication.
        if url_path.startswith("/v1") and not self.verify_token(headers):
            response = JSONResponse(content={"error": "Unauthorized"}, status_code=401)
            return response(scope, receive, send)
        return self.app(scope, receive, send)
class XRequestIdMiddleware:
    """
    Middleware that sets the X-Request-Id header on each response,
    using a random uuid4 (hex) value when the header isn't already
    present on the request and echoing the provided id otherwise.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        if scope["type"] not in ("http", "websocket"):
            return self.app(scope, receive, send)

        # Snapshot the incoming request headers.
        incoming_headers = Headers(scope=scope)

        async def send_with_request_id(message: Message) -> None:
            """Inject X-Request-Id into the response-start message."""
            if message["type"] == "http.response.start":
                outgoing_headers = MutableHeaders(raw=message["headers"])
                request_id = incoming_headers.get("X-Request-Id", uuid.uuid4().hex)
                outgoing_headers.append("X-Request-Id", request_id)
            await send(message)

        return self.app(scope, receive, send_with_request_id)
def load_log_config(log_config_file: str | None) -> dict | None:
if not log_config_file:
return None
try:
with open(log_config_file) as f:
return json.load(f)
except Exception as e:
logger.warning(
"Failed to load log config from file %s: error %s", log_config_file, e
)
return None
def get_uvicorn_log_config(args: Namespace) -> dict | None:
    """
    Get the uvicorn log config based on the provided arguments.

    Priority:
    1. If log_config_file is specified, use it
    2. If disable_access_log_for_endpoints is specified, create a config with
       the access log filter
    3. Otherwise, return None (use uvicorn defaults)
    """
    # A readable config file, when given, wins outright.
    file_config = load_log_config(args.log_config_file)
    if file_config is not None:
        return file_config

    if not args.disable_access_log_for_endpoints:
        return None

    from vllm.logging_utils import create_uvicorn_log_config

    # The CLI value is a comma-separated list of paths; drop empty entries.
    excluded_paths = [
        path.strip()
        for path in args.disable_access_log_for_endpoints.split(",")
        if path.strip()
    ]
    return create_uvicorn_log_config(
        excluded_paths=excluded_paths,
        log_level=args.uvicorn_log_level,
    )
def _extract_content_from_chunk(chunk_data: dict) -> str:
    """Extract content from a streaming response chunk.

    Prefers the typed pydantic models; falls back to manual dict parsing
    when validation fails. Returns "" when no content is present.
    """
    try:
        from vllm.entrypoints.openai.chat_completion.protocol import (
            ChatCompletionStreamResponse,
        )
        from vllm.entrypoints.openai.completion.protocol import (
            CompletionStreamResponse,
        )

        object_kind = chunk_data.get("object")
        # Try using Completion types for type-safe parsing.
        if object_kind == "chat.completion.chunk":
            parsed_chat = ChatCompletionStreamResponse.model_validate(chunk_data)
            if parsed_chat.choices and parsed_chat.choices[0].delta.content:
                return parsed_chat.choices[0].delta.content
        elif object_kind == "text_completion":
            parsed_completion = CompletionStreamResponse.model_validate(chunk_data)
            if parsed_completion.choices and parsed_completion.choices[0].text:
                return parsed_completion.choices[0].text
    except pydantic.ValidationError:
        # Fallback to manual parsing when the typed models reject the chunk.
        choices = chunk_data.get("choices")
        if choices:
            first_choice = choices[0]
            if "delta" in first_choice and first_choice["delta"].get("content"):
                return first_choice["delta"]["content"]
            if first_choice.get("text"):
                return first_choice["text"]
    return ""
class SSEDecoder:
    """Robust Server-Sent Events decoder for streaming responses."""

    def __init__(self):
        self.buffer = ""  # undecoded partial-line text carried between chunks
        self.content_buffer = []  # accumulated content fragments

    def decode_chunk(self, chunk: bytes) -> list[dict]:
        """Decode a chunk of SSE data and return parsed events."""
        import json

        try:
            text = chunk.decode("utf-8")
        except UnicodeDecodeError:
            # Skip malformed chunks
            return []

        self.buffer += text
        events = []
        # Consume every complete line; keep the trailing partial line buffered.
        while "\n" in self.buffer:
            line, _, self.buffer = self.buffer.partition("\n")
            line = line.rstrip("\r")  # Handle CRLF
            if not line.startswith("data: "):
                continue
            payload = line[6:].strip()
            if payload == "[DONE]":
                events.append({"type": "done"})
            elif payload:
                try:
                    events.append({"type": "data", "data": json.loads(payload)})
                except json.JSONDecodeError:
                    # Skip malformed JSON
                    pass
        return events

    def extract_content(self, event_data: dict) -> str:
        """Extract content from event data."""
        return _extract_content_from_chunk(event_data)

    def add_content(self, content: str) -> None:
        """Add content to the buffer."""
        if content:
            self.content_buffer.append(content)

    def get_complete_content(self) -> str:
        """Get the complete buffered content."""
        return "".join(self.content_buffer)
def _log_streaming_response(response, response_body: list) -> None:
    """Log streaming response with robust SSE parsing.

    Replaces the response's body iterator with one that replays the buffered
    chunks to the client while re-parsing the SSE stream, logging the
    assembled content once the stream signals [DONE].
    """
    from starlette.concurrency import iterate_in_threadpool

    sse_decoder = SSEDecoder()
    chunk_count = 0

    def buffered_iterator():
        nonlocal chunk_count

        for chunk in response_body:
            chunk_count += 1
            yield chunk

            # Parse SSE events from chunk
            events = sse_decoder.decode_chunk(chunk)
            for event in events:
                if event["type"] == "data":
                    content = sse_decoder.extract_content(event["data"])
                    sse_decoder.add_content(content)
                elif event["type"] == "done":
                    # Log complete content when done
                    full_content = sse_decoder.get_complete_content()
                    if full_content:
                        # Truncate if too long. Fix: the marker used to be a
                        # stray no-op expression and was never appended.
                        if len(full_content) > 2048:
                            full_content = full_content[:2048] + "...[truncated]"
                        logger.info(
                            "response_body={streaming_complete: content=%r, chunks=%d}",
                            full_content,
                            chunk_count,
                        )
                    else:
                        logger.info(
                            "response_body={streaming_complete: no_content, chunks=%d}",
                            chunk_count,
                        )
                    return

    response.body_iterator = iterate_in_threadpool(buffered_iterator())
    logger.info("response_body={streaming_started: chunks=%d}", len(response_body))
def _log_non_streaming_response(response_body: list) -> None:
    """Log non-streaming response."""
    try:
        body_text = response_body[0].decode()
    except UnicodeDecodeError:
        # Body is not valid UTF-8; don't dump raw bytes into the log.
        logger.info("response_body={<binary_data>}")
    else:
        logger.info("response_body={%s}", body_text)
async def log_response(request: Request, call_next):
    """HTTP middleware that buffers each response body and logs it.

    NOTE: the whole body is materialized in memory before being replayed to
    the client, so this is intended for debugging, not production traffic.
    """
    response = await call_next(request)

    # Drain the body iterator so we can both log the body and replay it.
    sections = [section async for section in response.body_iterator]
    response.body_iterator = iterate_in_threadpool(iter(sections))

    # SSE responses are parsed incrementally; everything else in one shot.
    is_streaming = (
        response.headers.get("content-type", "") == "text/event-stream; charset=utf-8"
    )
    if not sections:
        logger.info("response_body={<empty>}")
    elif is_streaming:
        _log_streaming_response(response, sections)
    else:
        _log_non_streaming_response(sections)
    return response
async def http_exception_handler(_: Request, exc: HTTPException):
    """Convert a FastAPI HTTPException into an OpenAI-style error response."""
    status = HTTPStatus(exc.status_code)
    error_info = ErrorInfo(
        message=sanitize_message(exc.detail),
        type=status.phrase,
        code=exc.status_code,
    )
    err = ErrorResponse(error=error_info)
    return JSONResponse(err.model_dump(), status_code=exc.status_code)
async def validation_exception_handler(_: Request, exc: RequestValidationError):
    """Convert a request-validation failure into an OpenAI-style 400 response."""
    errors = exc.errors()

    # Surface the offending parameter name when the underlying error is a
    # VLLMValidationError, which carries that metadata in its context.
    param = None
    for error in errors:
        if "ctx" in error and "error" in error["ctx"]:
            ctx_error = error["ctx"]["error"]
            if isinstance(ctx_error, VLLMValidationError):
                param = ctx_error.parameter
                break

    exc_str = str(exc)
    errors_str = str(errors)
    # Append the per-field errors only when they add information beyond
    # the exception's own string form.
    if errors and errors_str and errors_str != exc_str:
        message = f"{exc_str} {errors_str}"
    else:
        message = exc_str

    error_info = ErrorInfo(
        message=sanitize_message(message),
        type=HTTPStatus.BAD_REQUEST.phrase,
        code=HTTPStatus.BAD_REQUEST,
        param=param,
    )
    err = ErrorResponse(error=error_info)
    return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
# Strong references to background tasks so they aren't garbage collected
# while still running.
_running_tasks: set[asyncio.Task] = set()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan context for the API server.

    On startup: if stats logging is enabled, spawns a background task that
    calls ``engine_client.do_log_stats()`` every ``VLLM_LOG_STATS_INTERVAL``
    seconds, then freezes the GC heap. On shutdown: cancels the logging task
    and drops ``app.state`` so the engine reference can be collected.
    """
    try:
        if app.state.log_stats:
            engine_client: EngineClient = app.state.engine_client

            async def _force_log():
                # Periodically ask the engine to emit its stats.
                while True:
                    await asyncio.sleep(envs.VLLM_LOG_STATS_INTERVAL)
                    await engine_client.do_log_stats()

            task = asyncio.create_task(_force_log())
            # Hold a strong reference so the task isn't garbage collected
            # while running; drop it automatically when the task finishes.
            _running_tasks.add(task)
            task.add_done_callback(_running_tasks.remove)
        else:
            task = None
        # Mark the startup heap as static so that it's ignored by GC.
        # Reduces pause times of oldest generation collections.
        freeze_gc_heap()
        try:
            yield
        finally:
            if task is not None:
                task.cancel()
    finally:
        # Ensure app state including engine ref is gc'd
        del app.state

View File

@@ -10,7 +10,7 @@ import pydantic
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, Response
from vllm.entrypoints.openai.api_server import base
from vllm.entrypoints.openai.basic.api_router import base
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.utils import validate_json_request
@@ -158,3 +158,7 @@ def attach_router(app: FastAPI, supported_tasks: tuple["SupportedTask", ...]):
return JSONResponse(content=res.model_dump(), status_code=res.error.code)
app.include_router(router)
def sagemaker_standards_bootstrap(app: FastAPI) -> FastAPI:
    """Wrap *app* with the SageMaker model-hosting standards bootstrap."""
    return sagemaker_standards.bootstrap(app)