diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 85aec7a88..9639ba28e 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -1,34 +1,23 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import asyncio
-import hashlib
 import importlib
 import inspect
-import json
 import multiprocessing
 import multiprocessing.forkserver as forkserver
 import os
-import secrets
 import signal
 import socket
 import tempfile
-import uuid
 from argparse import Namespace
-from collections.abc import AsyncIterator, Awaitable
+from collections.abc import AsyncIterator
 from contextlib import asynccontextmanager
-from http import HTTPStatus
 from typing import Any
 
-import model_hosting_container_standards.sagemaker as sagemaker_standards
-import pydantic
 import uvloop
-from fastapi import APIRouter, FastAPI, HTTPException, Request
+from fastapi import FastAPI, HTTPException
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import JSONResponse
-from starlette.concurrency import iterate_in_threadpool
-from starlette.datastructures import URL, Headers, MutableHeaders, State
-from starlette.types import ASGIApp, Message, Receive, Scope, Send
+from starlette.datastructures import State
 
 import vllm.envs as envs
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -37,15 +26,16 @@ from vllm.entrypoints.chat_utils import load_chat_template
 from vllm.entrypoints.launcher import serve_http
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
-from vllm.entrypoints.openai.engine.protocol import (
-    ErrorInfo,
-    ErrorResponse,
-)
-from vllm.entrypoints.openai.engine.serving import OpenAIServing
 from vllm.entrypoints.openai.models.protocol import BaseModelPath
-from vllm.entrypoints.openai.models.serving import (
-    OpenAIServingModels,
+from vllm.entrypoints.openai.models.serving import OpenAIServingModels
+from vllm.entrypoints.openai.server_utils import (
+    get_uvicorn_log_config,
+    http_exception_handler,
+    lifespan,
+    log_response,
+    validation_exception_handler,
 )
+from vllm.entrypoints.sagemaker.api_router import sagemaker_standards_bootstrap
 from vllm.entrypoints.serve.elastic_ep.middleware import (
     ScalingMiddleware,
 )
@@ -55,16 +45,13 @@ from vllm.entrypoints.utils import (
     log_non_default_args,
     log_version_and_model,
     process_lora_modules,
-    sanitize_message,
 )
-from vllm.exceptions import VLLMValidationError
 from vllm.logger import init_logger
 from vllm.reasoning import ReasoningParserManager
 from vllm.tasks import POOLING_TASKS, SupportedTask
 from vllm.tool_parsers import ToolParserManager
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils.argparse_utils import FlexibleArgumentParser
-from vllm.utils.gc_utils import freeze_gc_heap
 from vllm.utils.network_utils import is_valid_ipv6_address
 from vllm.utils.system_utils import decorate_logs, set_ulimit
 from vllm.version import __version__ as VLLM_VERSION
@@ -75,39 +62,6 @@ prometheus_multiproc_dir: tempfile.TemporaryDirectory
 
 logger = init_logger("vllm.entrypoints.openai.api_server")
 
-_running_tasks: set[asyncio.Task] = set()
-
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
-    try:
-        if app.state.log_stats:
-            engine_client: EngineClient = app.state.engine_client
-
-            async def _force_log():
-                while True:
-                    await asyncio.sleep(envs.VLLM_LOG_STATS_INTERVAL)
-                    await engine_client.do_log_stats()
-
-            task = asyncio.create_task(_force_log())
-            _running_tasks.add(task)
-            task.add_done_callback(_running_tasks.remove)
-        else:
-            task = None
-
-        # Mark the startup heap as static so that it's ignored by GC.
-        # Reduces pause times of oldest generation collections.
-        freeze_gc_heap()
-        try:
-            yield
-        finally:
-            if task is not None:
-                task.cancel()
-    finally:
-        # Ensure app state including engine ref is gc'd
-        del app.state
-
-
 @asynccontextmanager
 async def build_async_engine_client(
     args: Namespace,
@@ -197,313 +151,6 @@ async def build_async_engine_client_from_engine_args(
         async_llm.shutdown()
 
 
-router = APIRouter()
-
-
-def base(request: Request) -> OpenAIServing:
-    # Reuse the existing instance
-    return tokenization(request)
-
-
-def tokenization(request: Request) -> OpenAIServingTokenization:
-    return request.app.state.openai_serving_tokenization
-
-
-def engine_client(request: Request) -> EngineClient:
-    return request.app.state.engine_client
-
-
-@router.get("/load")
-async def get_server_load_metrics(request: Request):
-    # This endpoint returns the current server load metrics.
-    # It tracks requests utilizing the GPU from the following routes:
-    # - /v1/responses
-    # - /v1/responses/{response_id}
-    # - /v1/responses/{response_id}/cancel
-    # - /v1/messages
-    # - /v1/chat/completions
-    # - /v1/completions
-    # - /v1/audio/transcriptions
-    # - /v1/audio/translations
-    # - /v1/embeddings
-    # - /pooling
-    # - /classify
-    # - /score
-    # - /v1/score
-    # - /rerank
-    # - /v1/rerank
-    # - /v2/rerank
-    return JSONResponse(content={"server_load": request.app.state.server_load_metrics})
-
-
-@router.get("/version")
-async def show_version():
-    ver = {"version": VLLM_VERSION}
-    return JSONResponse(content=ver)
-
-
-def load_log_config(log_config_file: str | None) -> dict | None:
-    if not log_config_file:
-        return None
-    try:
-        with open(log_config_file) as f:
-            return json.load(f)
-    except Exception as e:
-        logger.warning(
-            "Failed to load log config from file %s: error %s", log_config_file, e
-        )
-        return None
-
-
-def get_uvicorn_log_config(args: Namespace) -> dict | None:
-    """
-    Get the uvicorn log config based on the provided arguments.
-
-    Priority:
-    1. If log_config_file is specified, use it
-    2. If disable_access_log_for_endpoints is specified, create a config with
-       the access log filter
-    3. Otherwise, return None (use uvicorn defaults)
-    """
-    # First, try to load from file if specified
-    log_config = load_log_config(args.log_config_file)
-    if log_config is not None:
-        return log_config
-
-    # If endpoints to filter are specified, create a config with the filter
-    if args.disable_access_log_for_endpoints:
-        from vllm.logging_utils import create_uvicorn_log_config
-
-        # Parse comma-separated string into list
-        excluded_paths = [
-            p.strip()
-            for p in args.disable_access_log_for_endpoints.split(",")
-            if p.strip()
-        ]
-        return create_uvicorn_log_config(
-            excluded_paths=excluded_paths,
-            log_level=args.uvicorn_log_level,
-        )
-
-    return None
-
-
-class AuthenticationMiddleware:
-    """
-    Pure ASGI middleware that authenticates each request by checking
-    if the Authorization Bearer token exists and equals anyof "{api_key}".
-
-    Notes
-    -----
-    There are two cases in which authentication is skipped:
-        1. The HTTP method is OPTIONS.
-        2. The request path doesn't start with /v1 (e.g. /health).
- """ - - def __init__(self, app: ASGIApp, tokens: list[str]) -> None: - self.app = app - self.api_tokens = [hashlib.sha256(t.encode("utf-8")).digest() for t in tokens] - - def verify_token(self, headers: Headers) -> bool: - authorization_header_value = headers.get("Authorization") - if not authorization_header_value: - return False - - scheme, _, param = authorization_header_value.partition(" ") - if scheme.lower() != "bearer": - return False - - param_hash = hashlib.sha256(param.encode("utf-8")).digest() - - token_match = False - for token_hash in self.api_tokens: - token_match |= secrets.compare_digest(param_hash, token_hash) - - return token_match - - def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]: - if scope["type"] not in ("http", "websocket") or scope["method"] == "OPTIONS": - # scope["type"] can be "lifespan" or "startup" for example, - # in which case we don't need to do anything - return self.app(scope, receive, send) - root_path = scope.get("root_path", "") - url_path = URL(scope=scope).path.removeprefix(root_path) - headers = Headers(scope=scope) - # Type narrow to satisfy mypy. - if url_path.startswith("/v1") and not self.verify_token(headers): - response = JSONResponse(content={"error": "Unauthorized"}, status_code=401) - return response(scope, receive, send) - return self.app(scope, receive, send) - - -class XRequestIdMiddleware: - """ - Middleware the set's the X-Request-Id header for each response - to a random uuid4 (hex) value if the header isn't already - present in the request, otherwise use the provided request id. - """ - - def __init__(self, app: ASGIApp) -> None: - self.app = app - - def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]: - if scope["type"] not in ("http", "websocket"): - return self.app(scope, receive, send) - - # Extract the request headers. - request_headers = Headers(scope=scope) - - async def send_with_request_id(message: Message) -> None: - """ - Custom send function to mutate the response headers - and append X-Request-Id to it. 
- """ - if message["type"] == "http.response.start": - response_headers = MutableHeaders(raw=message["headers"]) - request_id = request_headers.get("X-Request-Id", uuid.uuid4().hex) - response_headers.append("X-Request-Id", request_id) - await send(message) - - return self.app(scope, receive, send_with_request_id) - - -def _extract_content_from_chunk(chunk_data: dict) -> str: - """Extract content from a streaming response chunk.""" - try: - from vllm.entrypoints.openai.chat_completion.protocol import ( - ChatCompletionStreamResponse, - ) - from vllm.entrypoints.openai.completion.protocol import ( - CompletionStreamResponse, - ) - - # Try using Completion types for type-safe parsing - if chunk_data.get("object") == "chat.completion.chunk": - chat_response = ChatCompletionStreamResponse.model_validate(chunk_data) - if chat_response.choices and chat_response.choices[0].delta.content: - return chat_response.choices[0].delta.content - elif chunk_data.get("object") == "text_completion": - completion_response = CompletionStreamResponse.model_validate(chunk_data) - if completion_response.choices and completion_response.choices[0].text: - return completion_response.choices[0].text - except pydantic.ValidationError: - # Fallback to manual parsing - if "choices" in chunk_data and chunk_data["choices"]: - choice = chunk_data["choices"][0] - if "delta" in choice and choice["delta"].get("content"): - return choice["delta"]["content"] - elif choice.get("text"): - return choice["text"] - return "" - - -class SSEDecoder: - """Robust Server-Sent Events decoder for streaming responses.""" - - def __init__(self): - self.buffer = "" - self.content_buffer = [] - - def decode_chunk(self, chunk: bytes) -> list[dict]: - """Decode a chunk of SSE data and return parsed events.""" - import json - - try: - chunk_str = chunk.decode("utf-8") - except UnicodeDecodeError: - # Skip malformed chunks - return [] - - self.buffer += chunk_str - events = [] - - # Process complete lines - while "\n" in self.buffer: - line, self.buffer = self.buffer.split("\n", 1) - line = line.rstrip("\r") # Handle CRLF - - if line.startswith("data: "): - data_str = line[6:].strip() - if data_str == "[DONE]": - events.append({"type": "done"}) - elif data_str: - try: - event_data = json.loads(data_str) - events.append({"type": "data", "data": event_data}) - except json.JSONDecodeError: - # Skip malformed JSON - continue - - return events - - def extract_content(self, event_data: dict) -> str: - """Extract content from event data.""" - return _extract_content_from_chunk(event_data) - - def add_content(self, content: str) -> None: - """Add content to the buffer.""" - if content: - self.content_buffer.append(content) - - def get_complete_content(self) -> str: - """Get the complete buffered content.""" - return "".join(self.content_buffer) - - -def _log_streaming_response(response, response_body: list) -> None: - """Log streaming response with robust SSE parsing.""" - from starlette.concurrency import iterate_in_threadpool - - sse_decoder = SSEDecoder() - chunk_count = 0 - - def buffered_iterator(): - nonlocal chunk_count - - for chunk in response_body: - chunk_count += 1 - yield chunk - - # Parse SSE events from chunk - events = sse_decoder.decode_chunk(chunk) - - for event in events: - if event["type"] == "data": - content = sse_decoder.extract_content(event["data"]) - sse_decoder.add_content(content) - elif event["type"] == "done": - # Log complete content when done - full_content = sse_decoder.get_complete_content() - if full_content: - # 
Truncate if too long - if len(full_content) > 2048: - full_content = full_content[:2048] + "" - "...[truncated]" - logger.info( - "response_body={streaming_complete: content=%r, chunks=%d}", - full_content, - chunk_count, - ) - else: - logger.info( - "response_body={streaming_complete: no_content, chunks=%d}", - chunk_count, - ) - return - - response.body_iterator = iterate_in_threadpool(buffered_iterator()) - logger.info("response_body={streaming_started: chunks=%d}", len(response_body)) - - -def _log_non_streaming_response(response_body: list) -> None: - """Log non-streaming response.""" - try: - decoded_body = response_body[0].decode() - logger.info("response_body={%s}", decoded_body) - except UnicodeDecodeError: - logger.info("response_body={}") - - def build_app(args: Namespace, supported_tasks: tuple["SupportedTask", ...]) -> FastAPI: if args.disable_fastapi_docs: app = FastAPI( @@ -514,7 +161,10 @@ def build_app(args: Namespace, supported_tasks: tuple["SupportedTask", ...]) -> else: app = FastAPI(lifespan=lifespan) app.state.args = args - app.include_router(router) + + from vllm.entrypoints.openai.basic.api_router import register_basic_api_routers + + register_basic_api_routers(app) from vllm.entrypoints.serve import register_vllm_serve_api_routers @@ -560,51 +210,18 @@ def build_app(args: Namespace, supported_tasks: tuple["SupportedTask", ...]) -> allow_headers=args.allowed_headers, ) - @app.exception_handler(HTTPException) - async def http_exception_handler(_: Request, exc: HTTPException): - err = ErrorResponse( - error=ErrorInfo( - message=sanitize_message(exc.detail), - type=HTTPStatus(exc.status_code).phrase, - code=exc.status_code, - ) - ) - return JSONResponse(err.model_dump(), status_code=exc.status_code) - - @app.exception_handler(RequestValidationError) - async def validation_exception_handler(_: Request, exc: RequestValidationError): - param = None - errors = exc.errors() - for error in errors: - if "ctx" in error and "error" in error["ctx"]: - ctx_error = error["ctx"]["error"] - if isinstance(ctx_error, VLLMValidationError): - param = ctx_error.parameter - break - - exc_str = str(exc) - errors_str = str(errors) - - if errors and errors_str and errors_str != exc_str: - message = f"{exc_str} {errors_str}" - else: - message = exc_str - - err = ErrorResponse( - error=ErrorInfo( - message=sanitize_message(message), - type=HTTPStatus.BAD_REQUEST.phrase, - code=HTTPStatus.BAD_REQUEST, - param=param, - ) - ) - return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST) + app.exception_handler(HTTPException)(http_exception_handler) + app.exception_handler(RequestValidationError)(validation_exception_handler) # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY if tokens := [key for key in (args.api_key or [envs.VLLM_API_KEY]) if key]: + from vllm.entrypoints.openai.server_utils import AuthenticationMiddleware + app.add_middleware(AuthenticationMiddleware, tokens=tokens) if args.enable_request_id_headers: + from vllm.entrypoints.openai.server_utils import XRequestIdMiddleware + app.add_middleware(XRequestIdMiddleware) # Add scaling middleware to check for scaling state @@ -616,24 +233,7 @@ def build_app(args: Namespace, supported_tasks: tuple["SupportedTask", ...]) -> "This can include sensitive information and should be " "avoided in production." 
         )
-
-        @app.middleware("http")
-        async def log_response(request: Request, call_next):
-            response = await call_next(request)
-            response_body = [section async for section in response.body_iterator]
-            response.body_iterator = iterate_in_threadpool(iter(response_body))
-            # Check if this is a streaming response by looking at content-type
-            content_type = response.headers.get("content-type", "")
-            is_streaming = content_type == "text/event-stream; charset=utf-8"
-
-            # Log response body based on type
-            if not response_body:
-                logger.info("response_body={}")
-            elif is_streaming:
-                _log_streaming_response(response, response_body)
-            else:
-                _log_non_streaming_response(response_body)
-            return response
+        app.middleware("http")(log_response)
 
     for middleware in args.middleware:
         module_path, object_name = middleware.rsplit(".", 1)
@@ -647,8 +247,7 @@ def build_app(args: Namespace, supported_tasks: tuple["SupportedTask", ...]) ->
                 f"Invalid middleware {middleware}. Must be a function or a class."
             )
 
-    app = sagemaker_standards.bootstrap(app)
-
+    app = sagemaker_standards_bootstrap(app)
     return app
diff --git a/vllm/entrypoints/openai/basic/__init__.py b/vllm/entrypoints/openai/basic/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/vllm/entrypoints/openai/basic/api_router.py b/vllm/entrypoints/openai/basic/api_router.py
new file mode 100644
index 000000000..3378d914a
--- /dev/null
+++ b/vllm/entrypoints/openai/basic/api_router.py
@@ -0,0 +1,61 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from fastapi import APIRouter, FastAPI, Request
+from fastapi.responses import JSONResponse
+
+from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.openai.engine.serving import OpenAIServing
+from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
+from vllm.logger import init_logger
+from vllm.version import __version__ as VLLM_VERSION
+
+router = APIRouter()
+
+logger = init_logger(__name__)
+
+
+def base(request: Request) -> OpenAIServing:
+    # Reuse the existing instance
+    return tokenization(request)
+
+
+def tokenization(request: Request) -> OpenAIServingTokenization:
+    return request.app.state.openai_serving_tokenization
+
+
+def engine_client(request: Request) -> EngineClient:
+    return request.app.state.engine_client
+
+
+@router.get("/load")
+async def get_server_load_metrics(request: Request):
+    # This endpoint returns the current server load metrics.
+    # It tracks requests utilizing the GPU from the following routes:
+    # - /v1/responses
+    # - /v1/responses/{response_id}
+    # - /v1/responses/{response_id}/cancel
+    # - /v1/messages
+    # - /v1/chat/completions
+    # - /v1/completions
+    # - /v1/audio/transcriptions
+    # - /v1/audio/translations
+    # - /v1/embeddings
+    # - /pooling
+    # - /classify
+    # - /score
+    # - /v1/score
+    # - /rerank
+    # - /v1/rerank
+    # - /v2/rerank
+    return JSONResponse(content={"server_load": request.app.state.server_load_metrics})
+
+
+@router.get("/version")
+async def show_version():
+    ver = {"version": VLLM_VERSION}
+    return JSONResponse(content=ver)
+
+
+def register_basic_api_routers(app: FastAPI):
+    app.include_router(router)
diff --git a/vllm/entrypoints/openai/server_utils.py b/vllm/entrypoints/openai/server_utils.py
new file mode 100644
index 000000000..12768cb6f
--- /dev/null
+++ b/vllm/entrypoints/openai/server_utils.py
@@ -0,0 +1,382 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import asyncio
+import hashlib
+import json
+import secrets
+import uuid
+from argparse import Namespace
+from collections.abc import Awaitable
+from contextlib import asynccontextmanager
+from http import HTTPStatus
+
+import pydantic
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.exceptions import RequestValidationError
+from fastapi.responses import JSONResponse
+from starlette.concurrency import iterate_in_threadpool
+from starlette.datastructures import URL, Headers, MutableHeaders
+from starlette.types import ASGIApp, Message, Receive, Scope, Send
+
+from vllm import envs
+from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.openai.engine.protocol import ErrorInfo, ErrorResponse
+from vllm.entrypoints.utils import sanitize_message
+from vllm.exceptions import VLLMValidationError
+from vllm.logger import init_logger
+from vllm.utils.gc_utils import freeze_gc_heap
+
+logger = init_logger("vllm.entrypoints.openai.server_utils")
+
+
+class AuthenticationMiddleware:
+    """
+    Pure ASGI middleware that authenticates each request by checking
+    if the Authorization Bearer token exists and equals any of "{api_key}".
+
+    Notes
+    -----
+    There are two cases in which authentication is skipped:
+        1. The HTTP method is OPTIONS.
+        2. The request path doesn't start with /v1 (e.g. /health).
+    """
+
+    def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
+        self.app = app
+        self.api_tokens = [hashlib.sha256(t.encode("utf-8")).digest() for t in tokens]
+
+    def verify_token(self, headers: Headers) -> bool:
+        authorization_header_value = headers.get("Authorization")
+        if not authorization_header_value:
+            return False
+
+        scheme, _, param = authorization_header_value.partition(" ")
+        if scheme.lower() != "bearer":
+            return False
+
+        param_hash = hashlib.sha256(param.encode("utf-8")).digest()
+
+        token_match = False
+        for token_hash in self.api_tokens:
+            token_match |= secrets.compare_digest(param_hash, token_hash)
+
+        return token_match
+
+    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
+        if scope["type"] not in ("http", "websocket") or scope["method"] == "OPTIONS":
+            # scope["type"] can be "lifespan" or "startup" for example,
+            # in which case we don't need to do anything
+            return self.app(scope, receive, send)
+        root_path = scope.get("root_path", "")
+        url_path = URL(scope=scope).path.removeprefix(root_path)
+        headers = Headers(scope=scope)
+        # Type narrow to satisfy mypy.
+        if url_path.startswith("/v1") and not self.verify_token(headers):
+            response = JSONResponse(content={"error": "Unauthorized"}, status_code=401)
+            return response(scope, receive, send)
+        return self.app(scope, receive, send)
+
+
+class XRequestIdMiddleware:
+    """
+    Middleware that sets the X-Request-Id header for each response
+    to a random uuid4 (hex) value if the header isn't already
+    present in the request; otherwise it reuses the provided request id.
+    """
+
+    def __init__(self, app: ASGIApp) -> None:
+        self.app = app
+
+    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
+        if scope["type"] not in ("http", "websocket"):
+            return self.app(scope, receive, send)
+
+        # Extract the request headers.
+        request_headers = Headers(scope=scope)
+
+        async def send_with_request_id(message: Message) -> None:
+            """
+            Custom send function to mutate the response headers
+            and append X-Request-Id to it.
+            """
+            if message["type"] == "http.response.start":
+                response_headers = MutableHeaders(raw=message["headers"])
+                request_id = request_headers.get("X-Request-Id", uuid.uuid4().hex)
+                response_headers.append("X-Request-Id", request_id)
+            await send(message)
+
+        return self.app(scope, receive, send_with_request_id)
+
+
+def load_log_config(log_config_file: str | None) -> dict | None:
+    if not log_config_file:
+        return None
+    try:
+        with open(log_config_file) as f:
+            return json.load(f)
+    except Exception as e:
+        logger.warning(
+            "Failed to load log config from file %s: error %s", log_config_file, e
+        )
+        return None
+
+
+def get_uvicorn_log_config(args: Namespace) -> dict | None:
+    """
+    Get the uvicorn log config based on the provided arguments.
+
+    Priority:
+    1. If log_config_file is specified, use it
+    2. If disable_access_log_for_endpoints is specified, create a config with
+       the access log filter
+    3. Otherwise, return None (use uvicorn defaults)
+    """
+    # First, try to load from file if specified
+    log_config = load_log_config(args.log_config_file)
+    if log_config is not None:
+        return log_config
+
+    # If endpoints to filter are specified, create a config with the filter
+    if args.disable_access_log_for_endpoints:
+        from vllm.logging_utils import create_uvicorn_log_config
+
+        # Parse comma-separated string into list
+        excluded_paths = [
+            p.strip()
+            for p in args.disable_access_log_for_endpoints.split(",")
+            if p.strip()
+        ]
+        return create_uvicorn_log_config(
+            excluded_paths=excluded_paths,
+            log_level=args.uvicorn_log_level,
+        )
+
+    return None
+
+
+def _extract_content_from_chunk(chunk_data: dict) -> str:
+    """Extract content from a streaming response chunk."""
+    try:
+        from vllm.entrypoints.openai.chat_completion.protocol import (
+            ChatCompletionStreamResponse,
+        )
+        from vllm.entrypoints.openai.completion.protocol import (
+            CompletionStreamResponse,
+        )
+
+        # Try using Completion types for type-safe parsing
+        if chunk_data.get("object") == "chat.completion.chunk":
+            chat_response = ChatCompletionStreamResponse.model_validate(chunk_data)
+            if chat_response.choices and chat_response.choices[0].delta.content:
+                return chat_response.choices[0].delta.content
+        elif chunk_data.get("object") == "text_completion":
+            completion_response = CompletionStreamResponse.model_validate(chunk_data)
+            if completion_response.choices and completion_response.choices[0].text:
+                return completion_response.choices[0].text
+    except pydantic.ValidationError:
+        # Fallback to manual parsing
+        if "choices" in chunk_data and chunk_data["choices"]:
+            choice = chunk_data["choices"][0]
+            if "delta" in choice and choice["delta"].get("content"):
+                return choice["delta"]["content"]
+            elif choice.get("text"):
+                return choice["text"]
+    return ""
+
+
+class SSEDecoder:
+    """Robust Server-Sent Events decoder for streaming responses."""
+
+    def __init__(self):
+        self.buffer = ""
+        self.content_buffer = []
+
+    def decode_chunk(self, chunk: bytes) -> list[dict]:
+        """Decode a chunk of SSE data and return parsed events."""
+        import json
+
+        try:
+            chunk_str = chunk.decode("utf-8")
+        except UnicodeDecodeError:
+            # Skip malformed chunks
+            return []
+
+        self.buffer += chunk_str
+        events = []
+
+        # Process complete lines
+        while "\n" in self.buffer:
+            line, self.buffer = self.buffer.split("\n", 1)
+            line = line.rstrip("\r")  # Handle CRLF
+
+            if line.startswith("data: "):
+                data_str = line[6:].strip()
+                if data_str == "[DONE]":
+                    events.append({"type": "done"})
+                elif data_str:
+                    try:
+                        event_data = json.loads(data_str)
+                        events.append({"type": "data", "data": event_data})
+                    except json.JSONDecodeError:
+                        # Skip malformed JSON
+                        continue
+
+        return events
+
+    def extract_content(self, event_data: dict) -> str:
+        """Extract content from event data."""
+        return _extract_content_from_chunk(event_data)
+
+    def add_content(self, content: str) -> None:
+        """Add content to the buffer."""
+        if content:
+            self.content_buffer.append(content)
+
+    def get_complete_content(self) -> str:
+        """Get the complete buffered content."""
+        return "".join(self.content_buffer)
+
+
+def _log_streaming_response(response, response_body: list) -> None:
+    """Log streaming response with robust SSE parsing."""
+    from starlette.concurrency import iterate_in_threadpool
+
+    sse_decoder = SSEDecoder()
+    chunk_count = 0
+
+    def buffered_iterator():
+        nonlocal chunk_count
+
+        for chunk in response_body:
+            chunk_count += 1
+            yield chunk
+
+            # Parse SSE events from chunk
+            events = sse_decoder.decode_chunk(chunk)
+
+            for event in events:
+                if event["type"] == "data":
+                    content = sse_decoder.extract_content(event["data"])
+                    sse_decoder.add_content(content)
+                elif event["type"] == "done":
+                    # Log complete content when done
+                    full_content = sse_decoder.get_complete_content()
+                    if full_content:
+                        # Truncate if too long
+                        if len(full_content) > 2048:
+                            full_content = full_content[:2048] + "...[truncated]"
+                        logger.info(
+                            "response_body={streaming_complete: content=%r, chunks=%d}",
+                            full_content,
+                            chunk_count,
+                        )
+                    else:
+                        logger.info(
+                            "response_body={streaming_complete: no_content, chunks=%d}",
+                            chunk_count,
+                        )
+                    return
+
+    response.body_iterator = iterate_in_threadpool(buffered_iterator())
+    logger.info("response_body={streaming_started: chunks=%d}", len(response_body))
+
+
+def _log_non_streaming_response(response_body: list) -> None:
+    """Log non-streaming response."""
+    try:
+        decoded_body = response_body[0].decode()
+        logger.info("response_body={%s}", decoded_body)
+    except UnicodeDecodeError:
+        logger.info("response_body={}")
+
+
+async def log_response(request: Request, call_next):
+    response = await call_next(request)
+    response_body = [section async for section in response.body_iterator]
+    response.body_iterator = iterate_in_threadpool(iter(response_body))
+    # Check if this is a streaming response by looking at content-type
+    content_type = response.headers.get("content-type", "")
+    is_streaming = content_type == "text/event-stream; charset=utf-8"
+
+    # Log response body based on type
+    if not response_body:
+        logger.info("response_body={}")
+    elif is_streaming:
+        _log_streaming_response(response, response_body)
+    else:
+        _log_non_streaming_response(response_body)
+    return response
+
+
+async def http_exception_handler(_: Request, exc: HTTPException):
+    err = ErrorResponse(
+        error=ErrorInfo(
+            message=sanitize_message(exc.detail),
+            type=HTTPStatus(exc.status_code).phrase,
+            code=exc.status_code,
+        )
+    )
+    return JSONResponse(err.model_dump(), status_code=exc.status_code)
+
+
+async def validation_exception_handler(_: Request, exc: RequestValidationError):
+    param = None
+    errors = exc.errors()
+    for error in errors:
+        if "ctx" in error and "error" in error["ctx"]:
+            ctx_error = error["ctx"]["error"]
+            if isinstance(ctx_error, VLLMValidationError):
+                param = ctx_error.parameter
+                break
+
+    exc_str = str(exc)
+    errors_str = str(errors)
+
+    if errors and errors_str and errors_str != exc_str:
+        message = f"{exc_str} {errors_str}"
+    else:
+        message = exc_str
+
+    err = ErrorResponse(
+        error=ErrorInfo(
+            message=sanitize_message(message),
+            type=HTTPStatus.BAD_REQUEST.phrase,
+            code=HTTPStatus.BAD_REQUEST,
+            param=param,
+        )
+    )
+    return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
+
+
+_running_tasks: set[asyncio.Task] = set()
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    try:
+        if app.state.log_stats:
+            engine_client: EngineClient = app.state.engine_client
+
+            async def _force_log():
+                while True:
+                    await asyncio.sleep(envs.VLLM_LOG_STATS_INTERVAL)
+                    await engine_client.do_log_stats()
+
+            task = asyncio.create_task(_force_log())
+            _running_tasks.add(task)
+            task.add_done_callback(_running_tasks.remove)
+        else:
+            task = None
+
+        # Mark the startup heap as static so that it's ignored by GC.
+        # Reduces pause times of oldest generation collections.
+        freeze_gc_heap()
+        try:
+            yield
+        finally:
+            if task is not None:
+                task.cancel()
+    finally:
+        # Ensure app state including engine ref is gc'd
+        del app.state
diff --git a/vllm/entrypoints/sagemaker/api_router.py b/vllm/entrypoints/sagemaker/api_router.py
index 8e5dd3db2..7c5bae5b5 100644
--- a/vllm/entrypoints/sagemaker/api_router.py
+++ b/vllm/entrypoints/sagemaker/api_router.py
@@ -10,7 +10,7 @@ import pydantic
 from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
 from fastapi.responses import JSONResponse, Response
 
-from vllm.entrypoints.openai.api_server import base
+from vllm.entrypoints.openai.basic.api_router import base
 from vllm.entrypoints.openai.engine.protocol import ErrorResponse
 from vllm.entrypoints.openai.engine.serving import OpenAIServing
 from vllm.entrypoints.openai.utils import validate_json_request
@@ -158,3 +158,7 @@ def attach_router(app: FastAPI, supported_tasks: tuple["SupportedTask", ...]):
         return JSONResponse(content=res.model_dump(), status_code=res.error.code)
 
     app.include_router(router)
+
+
+def sagemaker_standards_bootstrap(app: FastAPI) -> FastAPI:
+    return sagemaker_standards.bootstrap(app)