diff --git a/tests/entrypoints/openai/test_chat_error.py b/tests/entrypoints/openai/test_chat_error.py index c5e82d147..a62340513 100644 --- a/tests/entrypoints/openai/test_chat_error.py +++ b/tests/entrypoints/openai/test_chat_error.py @@ -12,7 +12,8 @@ from vllm.config.multimodal import MultiModalConfig from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat from vllm.entrypoints.openai.engine.protocol import ErrorResponse -from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels +from vllm.entrypoints.openai.models.protocol import BaseModelPath +from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.outputs import CompletionOutput, RequestOutput from vllm.tokenizers import get_tokenizer from vllm.v1.engine.async_llm import AsyncLLM diff --git a/tests/entrypoints/openai/test_cli_args.py b/tests/entrypoints/openai/test_cli_args.py index 0d7e6ae37..dd5d62990 100644 --- a/tests/entrypoints/openai/test_cli_args.py +++ b/tests/entrypoints/openai/test_cli_args.py @@ -6,7 +6,7 @@ import json import pytest from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args -from vllm.entrypoints.openai.serving_models import LoRAModulePath +from vllm.entrypoints.openai.models.protocol import LoRAModulePath from vllm.utils.argparse_utils import FlexibleArgumentParser from ...utils import VLLM_PATH diff --git a/tests/entrypoints/openai/test_completion_error.py b/tests/entrypoints/openai/test_completion_error.py index 9b4539d47..7dd0448de 100644 --- a/tests/entrypoints/openai/test_completion_error.py +++ b/tests/entrypoints/openai/test_completion_error.py @@ -9,9 +9,11 @@ from unittest.mock import AsyncMock, MagicMock import pytest from vllm.config.multimodal import MultiModalConfig -from vllm.entrypoints.openai.engine.protocol import CompletionRequest, ErrorResponse -from 
vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion -from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels +from vllm.entrypoints.openai.completion.protocol import CompletionRequest +from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion +from vllm.entrypoints.openai.engine.protocol import ErrorResponse +from vllm.entrypoints.openai.models.protocol import BaseModelPath +from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.outputs import CompletionOutput, RequestOutput from vllm.tokenizers import get_tokenizer from vllm.v1.engine.async_llm import AsyncLLM diff --git a/tests/entrypoints/openai/test_gptoss_structural_tags_integration.py b/tests/entrypoints/openai/test_gptoss_structural_tags_integration.py index d8ee91f77..3d8ea3ef3 100644 --- a/tests/entrypoints/openai/test_gptoss_structural_tags_integration.py +++ b/tests/entrypoints/openai/test_gptoss_structural_tags_integration.py @@ -8,13 +8,11 @@ from unittest.mock import Mock import pytest -from vllm.entrypoints.openai.engine.protocol import ( - StructuredOutputsParams, -) from vllm.entrypoints.tool_server import ToolServer from vllm.reasoning.gptoss_reasoning_parser import ( GptOssReasoningParser, ) +from vllm.sampling_params import StructuredOutputsParams class TestGptOssStructuralTagsIntegration: diff --git a/tests/entrypoints/openai/test_lora_resolvers.py b/tests/entrypoints/openai/test_lora_resolvers.py index f740e7968..bfadf51e4 100644 --- a/tests/entrypoints/openai/test_lora_resolvers.py +++ b/tests/entrypoints/openai/test_lora_resolvers.py @@ -9,9 +9,11 @@ from unittest.mock import AsyncMock, MagicMock import pytest from vllm.config.multimodal import MultiModalConfig -from vllm.entrypoints.openai.engine.protocol import CompletionRequest, ErrorResponse -from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion -from vllm.entrypoints.openai.serving_models import BaseModelPath, 
OpenAIServingModels +from vllm.entrypoints.openai.completion.protocol import CompletionRequest +from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion +from vllm.entrypoints.openai.engine.protocol import ErrorResponse +from vllm.entrypoints.openai.models.protocol import BaseModelPath +from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.lora.request import LoRARequest from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry from vllm.tokenizers import get_tokenizer diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index d4f7b82a5..2e0b0a63f 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -20,8 +20,8 @@ from vllm.entrypoints.openai.engine.protocol import ( ErrorResponse, RequestResponseMetadata, ) +from vllm.entrypoints.openai.models.serving import BaseModelPath, OpenAIServingModels from vllm.entrypoints.openai.parser.harmony_utils import get_encoding -from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels from vllm.outputs import CompletionOutput, RequestOutput from vllm.tokenizers import get_tokenizer from vllm.tool_parsers import ToolParserManager diff --git a/tests/entrypoints/openai/test_serving_engine.py b/tests/entrypoints/openai/test_serving_engine.py index c2bc82514..654d42276 100644 --- a/tests/entrypoints/openai/test_serving_engine.py +++ b/tests/entrypoints/openai/test_serving_engine.py @@ -9,7 +9,7 @@ import pytest from vllm.config import ModelConfig from vllm.entrypoints.openai.engine.serving import OpenAIServing -from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.tokenizers.mistral import MistralTokenizer diff --git a/tests/entrypoints/openai/test_serving_models.py b/tests/entrypoints/openai/test_serving_models.py index a671611c7..88b168c7d 
100644 --- a/tests/entrypoints/openai/test_serving_models.py +++ b/tests/entrypoints/openai/test_serving_models.py @@ -11,7 +11,8 @@ from vllm.engine.protocol import EngineClient from vllm.entrypoints.openai.engine.protocol import ( ErrorResponse, ) -from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels +from vllm.entrypoints.openai.models.protocol import BaseModelPath +from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.serve.lora.protocol import ( LoadLoRAAdapterRequest, UnloadLoRAAdapterRequest, diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index 270092faf..adebafec2 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -19,7 +19,8 @@ from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat from vllm.entrypoints.openai.engine.protocol import ( ErrorResponse, ) -from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels +from vllm.entrypoints.openai.models.protocol import BaseModelPath +from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.inputs import PromptType from vllm.outputs import RequestOutput from vllm.platforms import current_platform diff --git a/vllm/entrypoints/anthropic/__init__.py b/vllm/entrypoints/anthropic/__init__.py index e69de29bb..208f01a7c 100644 --- a/vllm/entrypoints/anthropic/__init__.py +++ b/vllm/entrypoints/anthropic/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/vllm/entrypoints/anthropic/api_router.py b/vllm/entrypoints/anthropic/api_router.py new file mode 100644 index 000000000..1494dd7e5 --- /dev/null +++ b/vllm/entrypoints/anthropic/api_router.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +from http import HTTPStatus + +from 
from fastapi import APIRouter, Depends, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse

from vllm.entrypoints.anthropic.protocol import (
    AnthropicError,
    AnthropicErrorResponse,
    AnthropicMessagesRequest,
    AnthropicMessagesResponse,
)
from vllm.entrypoints.anthropic.serving import AnthropicServingMessages
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.utils import (
    load_aware_call,
    with_cancellation,
)
from vllm.logger import init_logger

logger = init_logger(__name__)

router = APIRouter()


def messages(request: Request) -> AnthropicServingMessages | None:
    """Return the Messages handler stored on the application state.

    The value is read from ``request.app.state.anthropic_serving_messages``.
    It may be ``None``; ``create_messages`` treats that as "the model does
    not support Messages API".
    """
    return request.app.state.anthropic_serving_messages


@router.post(
    "/v1/messages",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_messages(request: AnthropicMessagesRequest, raw_request: Request):
    """Handle ``POST /v1/messages``.

    Args:
        request: The parsed Anthropic Messages request body.
        raw_request: The underlying request, used to reach the application
            state and passed through to the serving handler.

    Returns:
        A ``JSONResponse`` carrying an ``AnthropicErrorResponse`` on any
        error, a ``JSONResponse`` with the dumped
        ``AnthropicMessagesResponse`` for a complete result, or a
        ``StreamingResponse`` (``text/event-stream``) otherwise.
    """

    def translate_error_response(response: ErrorResponse) -> JSONResponse:
        # Re-wrap the OpenAI-style error in the Anthropic error schema while
        # keeping the original error code as the HTTP status.
        anthropic_error = AnthropicErrorResponse(
            error=AnthropicError(
                type=response.error.type,
                message=response.error.message,
            )
        )
        return JSONResponse(
            status_code=response.error.code, content=anthropic_error.model_dump()
        )

    handler = messages(raw_request)
    if handler is None:
        # No Messages handler is configured: borrow the tokenization server
        # only to build a standard error object.
        base_server = raw_request.app.state.openai_serving_tokenization
        error = base_server.create_error_response(
            message="The model does not support Messages API"
        )
        return translate_error_response(error)

    try:
        generator = await handler.create_messages(request, raw_request)
    except Exception as e:
        # Top-level boundary: log with traceback and answer with a 500 in
        # the Anthropic error format instead of propagating.
        logger.exception("Error in create_messages: %s", e)
        return JSONResponse(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
            content=AnthropicErrorResponse(
                error=AnthropicError(
                    type="internal_error",
                    message=str(e),
                )
            ).model_dump(),
        )

    if isinstance(generator, ErrorResponse):
        return translate_error_response(generator)

    if isinstance(generator, AnthropicMessagesResponse):
        resp = generator.model_dump(exclude_none=True)
        logger.debug("Anthropic Messages Response: %s", resp)
        return JSONResponse(content=resp)

    # Anything else is handed to StreamingResponse as the event stream.
    return StreamingResponse(content=generator, media_type="text/event-stream")


def attach_router(app: FastAPI):
    """Register this module's routes on *app*."""
    app.include_router(router)
JSONResponse from starlette.concurrency import iterate_in_threadpool from starlette.datastructures import URL, Headers, MutableHeaders, State from starlette.types import ASGIApp, Message, Receive, Scope, Send @@ -33,36 +33,26 @@ from starlette.types import ASGIApp, Message, Receive, Scope, Send import vllm.envs as envs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.protocol import EngineClient -from vllm.entrypoints.anthropic.protocol import ( - AnthropicError, - AnthropicErrorResponse, - AnthropicMessagesRequest, - AnthropicMessagesResponse, -) -from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages +from vllm.entrypoints.anthropic.serving import AnthropicServingMessages from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args +from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion from vllm.entrypoints.openai.engine.protocol import ( - CompletionRequest, - CompletionResponse, ErrorInfo, ErrorResponse, ) from vllm.entrypoints.openai.engine.serving import OpenAIServing -from vllm.entrypoints.openai.orca_metrics import metrics_header -from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses -from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion -from vllm.entrypoints.openai.serving_models import ( - BaseModelPath, +from vllm.entrypoints.openai.models.protocol import BaseModelPath +from vllm.entrypoints.openai.models.serving import ( OpenAIServingModels, ) +from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses from vllm.entrypoints.openai.translations.serving import ( OpenAIServingTranscription, OpenAIServingTranslation, ) -from vllm.entrypoints.openai.utils import validate_json_request from 
vllm.entrypoints.pooling.classify.serving import ServingClassification from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling @@ -75,12 +65,10 @@ from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization from vllm.entrypoints.tool_server import DemoToolServer, MCPToolServer, ToolServer from vllm.entrypoints.utils import ( cli_env_setup, - load_aware_call, log_non_default_args, process_chat_template, process_lora_modules, sanitize_message, - with_cancellation, ) from vllm.exceptions import VLLMValidationError from vllm.logger import init_logger @@ -99,7 +87,6 @@ prometheus_multiproc_dir: tempfile.TemporaryDirectory # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) logger = init_logger("vllm.entrypoints.openai.api_server") -ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL = "endpoint-load-metrics-format" _running_tasks: set[asyncio.Task] = set() @@ -231,22 +218,6 @@ def base(request: Request) -> OpenAIServing: return tokenization(request) -def models(request: Request) -> OpenAIServingModels: - return request.app.state.openai_serving_models - - -def messages(request: Request) -> AnthropicServingMessages: - return request.app.state.anthropic_serving_messages - - -def chat(request: Request) -> OpenAIServingChat | None: - return request.app.state.openai_serving_chat - - -def completion(request: Request) -> OpenAIServingCompletion | None: - return request.app.state.openai_serving_completion - - def tokenization(request: Request) -> OpenAIServingTokenization: return request.app.state.openai_serving_tokenization @@ -278,116 +249,12 @@ async def get_server_load_metrics(request: Request): return JSONResponse(content={"server_load": request.app.state.server_load_metrics}) -@router.get("/v1/models") -async def show_available_models(raw_request: Request): - handler = models(raw_request) - - models_ = await handler.show_available_models() - return 
JSONResponse(content=models_.model_dump()) - - @router.get("/version") async def show_version(): ver = {"version": VLLM_VERSION} return JSONResponse(content=ver) -@router.post( - "/v1/messages", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, - HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse}, - HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse}, - HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse}, - }, -) -@with_cancellation -@load_aware_call -async def create_messages(request: AnthropicMessagesRequest, raw_request: Request): - def translate_error_response(response: ErrorResponse) -> JSONResponse: - anthropic_error = AnthropicErrorResponse( - error=AnthropicError( - type=response.error.type, - message=response.error.message, - ) - ) - return JSONResponse( - status_code=response.error.code, content=anthropic_error.model_dump() - ) - - handler = messages(raw_request) - if handler is None: - error = base(raw_request).create_error_response( - message="The model does not support Messages API" - ) - return translate_error_response(error) - - try: - generator = await handler.create_messages(request, raw_request) - except Exception as e: - logger.exception("Error in create_messages: %s", e) - return JSONResponse( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - content=AnthropicErrorResponse( - error=AnthropicError( - type="internal_error", - message=str(e), - ) - ).model_dump(), - ) - - if isinstance(generator, ErrorResponse): - return translate_error_response(generator) - - elif isinstance(generator, AnthropicMessagesResponse): - resp = generator.model_dump(exclude_none=True) - logger.debug("Anthropic Messages Response: %s", resp) - return JSONResponse(content=resp) - - return StreamingResponse(content=generator, media_type="text/event-stream") - - -@router.post( - "/v1/completions", - dependencies=[Depends(validate_json_request)], - 
responses={ - HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, - HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, - HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, - HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, - }, -) -@with_cancellation -@load_aware_call -async def create_completion(request: CompletionRequest, raw_request: Request): - metrics_header_format = raw_request.headers.get( - ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL, "" - ) - handler = completion(raw_request) - if handler is None: - return base(raw_request).create_error_response( - message="The model does not support Completions API" - ) - - try: - generator = await handler.create_completion(request, raw_request) - except Exception as e: - return handler.create_error_response(e) - - if isinstance(generator, ErrorResponse): - return JSONResponse( - content=generator.model_dump(), status_code=generator.error.code - ) - elif isinstance(generator, CompletionResponse): - return JSONResponse( - content=generator.model_dump(), - headers=metrics_header(metrics_header_format), - ) - - return StreamingResponse(content=generator, media_type="text/event-stream") - - def load_log_config(log_config_file: str | None) -> dict | None: if not log_config_file: return None @@ -486,7 +353,7 @@ def _extract_content_from_chunk(chunk_data: dict) -> str: from vllm.entrypoints.openai.chat_completion.protocol import ( ChatCompletionStreamResponse, ) - from vllm.entrypoints.openai.engine.protocol import ( + from vllm.entrypoints.openai.completion.protocol import ( CompletionStreamResponse, ) @@ -646,6 +513,22 @@ def build_app(args: Namespace) -> FastAPI: ) register_translations_api_router(app) + + from vllm.entrypoints.openai.completion.api_router import ( + attach_router as register_completion_api_router, + ) + + register_completion_api_router(app) + from vllm.entrypoints.anthropic.api_router import ( + attach_router as register_anthropic_api_router, + ) + + 
register_anthropic_api_router(app) + from vllm.entrypoints.openai.models.api_router import ( + attach_router as register_models_api_router, + ) + + register_models_api_router(app) from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes register_sagemaker_routes(router) diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index e161e407e..83d75a00f 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -54,6 +54,7 @@ from vllm.entrypoints.openai.engine.serving import ( OpenAIServing, clamp_prompt_logprobs, ) +from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.parser.harmony_utils import ( get_developer_message, get_stop_tokens_for_assistant_actions, @@ -63,7 +64,6 @@ from vllm.entrypoints.openai.parser.harmony_utils import ( parse_chat_output, render_for_completion, ) -from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.inputs.data import TokensPrompt diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 594130a1a..bafa845cc 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -26,7 +26,7 @@ from vllm.entrypoints.constants import ( H11_MAX_HEADER_COUNT_DEFAULT, H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT, ) -from vllm.entrypoints.openai.serving_models import LoRAModulePath +from vllm.entrypoints.openai.models.protocol import LoRAModulePath from vllm.logger import init_logger from vllm.tool_parsers import ToolParserManager from vllm.utils.argparse_utils import FlexibleArgumentParser diff --git a/vllm/entrypoints/openai/completion/__init__.py b/vllm/entrypoints/openai/completion/__init__.py new file mode 100644 index 
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


from http import HTTPStatus

from fastapi import APIRouter, Depends, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse

from vllm.entrypoints.openai.completion.protocol import (
    CompletionRequest,
    CompletionResponse,
)
from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.orca_metrics import metrics_header
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.utils import (
    load_aware_call,
    with_cancellation,
)
from vllm.logger import init_logger

logger = init_logger(__name__)

router = APIRouter()
# Request header whose value is forwarded to ``metrics_header`` to build the
# extra response headers of non-streaming completions.
ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL = "endpoint-load-metrics-format"


def completion(request: Request) -> OpenAIServingCompletion | None:
    """Return the Completions handler stored on the application state.

    The value is read from ``request.app.state.openai_serving_completion``
    and may be ``None``.
    """
    return request.app.state.openai_serving_completion


def _error_json_response(error: ErrorResponse) -> JSONResponse:
    """Serialize *error* as JSON, using the code it carries as HTTP status."""
    return JSONResponse(content=error.model_dump(), status_code=error.error.code)


@router.post(
    "/v1/completions",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_completion(request: CompletionRequest, raw_request: Request):
    """Handle ``POST /v1/completions``.

    Args:
        request: The parsed completion request body.
        raw_request: The underlying request, used to read headers and the
            application state and passed through to the serving handler.

    Returns:
        A ``JSONResponse`` with the error's own status code for every error
        path, a ``JSONResponse`` with the dumped ``CompletionResponse`` (plus
        the headers produced by ``metrics_header``) for a complete result, or
        a ``StreamingResponse`` (``text/event-stream``) otherwise.
    """
    metrics_header_format = raw_request.headers.get(
        ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL, ""
    )
    handler = completion(raw_request)
    if handler is None:
        # No Completions handler is configured: borrow the tokenization
        # server only to build a standard error object. Wrap it so the
        # client sees the error's status code rather than an implicit 200.
        base_server = raw_request.app.state.openai_serving_tokenization
        return _error_json_response(
            base_server.create_error_response(
                message="The model does not support Completions API"
            )
        )

    try:
        generator = await handler.create_completion(request, raw_request)
    except Exception as e:
        # Top-level boundary: convert the failure into an error payload.
        # NOTE(review): assumes create_error_response returns an
        # ErrorResponse, as it does for the tokenization server — confirm.
        return _error_json_response(handler.create_error_response(e))

    if isinstance(generator, ErrorResponse):
        return _error_json_response(generator)

    if isinstance(generator, CompletionResponse):
        return JSONResponse(
            content=generator.model_dump(),
            headers=metrics_header(metrics_header_format),
        )

    # Anything else is handed to StreamingResponse as the event stream.
    return StreamingResponse(content=generator, media_type="text/event-stream")


def attach_router(app: FastAPI):
    """Register this module's routes on *app*."""
    app.include_router(router)
StructuredOutputsParams, +) +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +_LONG_INFO = torch.iinfo(torch.long) + + +class CompletionRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/completions/create + model: str | None = None + prompt: list[int] | list[list[int]] | str | list[str] | None = None + echo: bool | None = False + frequency_penalty: float | None = 0.0 + logit_bias: dict[str, float] | None = None + logprobs: int | None = None + max_tokens: int | None = 16 + n: int = 1 + presence_penalty: float | None = 0.0 + seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) + stop: str | list[str] | None = [] + stream: bool | None = False + stream_options: StreamOptions | None = None + suffix: str | None = None + temperature: float | None = None + top_p: float | None = None + user: str | None = None + + # --8<-- [start:completion-sampling-params] + use_beam_search: bool = False + top_k: int | None = None + min_p: float | None = None + repetition_penalty: float | None = None + length_penalty: float = 1.0 + stop_token_ids: list[int] | None = [] + include_stop_str_in_output: bool = False + ignore_eos: bool = False + min_tokens: int = 0 + skip_special_tokens: bool = True + spaces_between_special_tokens: bool = True + truncate_prompt_tokens: Annotated[int, Field(ge=-1, le=_LONG_INFO.max)] | None = ( + None + ) + allowed_token_ids: list[int] | None = None + prompt_logprobs: int | None = None + # --8<-- [end:completion-sampling-params] + + # --8<-- [start:completion-extra-params] + prompt_embeds: bytes | list[bytes] | None = None + add_special_tokens: bool = Field( + default=True, + description=( + "If true (the default), special tokens (e.g. BOS) will be added to " + "the prompt." + ), + ) + response_format: AnyResponseFormat | None = Field( + default=None, + description=( + "Similar to chat completion, this parameter specifies the format " + "of output. 
Only {'type': 'json_object'}, {'type': 'json_schema'}" + ", {'type': 'structural_tag'}, or {'type': 'text' } is supported." + ), + ) + structured_outputs: StructuredOutputsParams | None = Field( + default=None, + description="Additional kwargs for structured outputs", + ) + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling." + ), + ) + request_id: str = Field( + default_factory=random_uuid, + description=( + "The request_id related to this request. If the caller does " + "not set it, a random_uuid will be generated. This id is used " + "through out the inference process and return in response." + ), + ) + logits_processors: LogitsProcessors | None = Field( + default=None, + description=( + "A list of either qualified names of logits processors, or " + "constructor objects, to apply when sampling. A constructor is " + "a JSON object with a required 'qualname' field specifying the " + "qualified name of the processor class/factory, and optional " + "'args' and 'kwargs' fields containing positional and keyword " + "arguments. For example: {'qualname': " + "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': " + "{'param': 'value'}}." + ), + ) + + return_tokens_as_token_ids: bool | None = Field( + default=None, + description=( + "If specified with 'logprobs', tokens are represented " + " as strings of the form 'token_id:{token_id}' so that tokens " + "that are not JSON-encodable can be identified." + ), + ) + return_token_ids: bool | None = Field( + default=None, + description=( + "If specified, the result will include token IDs alongside the " + "generated text. In streaming mode, prompt_token_ids is included " + "only in the first chunk, and token_ids contains the delta tokens " + "for each chunk. 
This is useful for debugging or when you " + "need to map generated text back to input tokens." + ), + ) + + cache_salt: str | None = Field( + default=None, + description=( + "If specified, the prefix cache will be salted with the provided " + "string to prevent an attacker to guess prompts in multi-user " + "environments. The salt should be random, protected from " + "access by 3rd parties, and long enough to be " + "unpredictable (e.g., 43 characters base64-encoded, corresponding " + "to 256 bit)." + ), + ) + + kv_transfer_params: dict[str, Any] | None = Field( + default=None, + description="KVTransfer parameters used for disaggregated serving.", + ) + + vllm_xargs: dict[str, str | int | float] | None = Field( + default=None, + description=( + "Additional request parameters with string or " + "numeric values, used by custom extensions." + ), + ) + + # --8<-- [end:completion-extra-params] + + # Default sampling parameters for completion requests + _DEFAULT_SAMPLING_PARAMS: dict = { + "repetition_penalty": 1.0, + "temperature": 1.0, + "top_p": 1.0, + "top_k": 0, + "min_p": 0.0, + } + + def to_beam_search_params( + self, + max_tokens: int, + default_sampling_params: dict | None = None, + ) -> BeamSearchParams: + if default_sampling_params is None: + default_sampling_params = {} + n = self.n if self.n is not None else 1 + + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get("temperature", 1.0) + + return BeamSearchParams( + beam_width=n, + max_tokens=max_tokens, + ignore_eos=self.ignore_eos, + temperature=temperature, + length_penalty=self.length_penalty, + include_stop_str_in_output=self.include_stop_str_in_output, + ) + + def to_sampling_params( + self, + max_tokens: int, + logits_processor_pattern: str | None, + default_sampling_params: dict | None = None, + ) -> SamplingParams: + if default_sampling_params is None: + default_sampling_params = {} + + # Default parameters + if (repetition_penalty := self.repetition_penalty) is 
None: + repetition_penalty = default_sampling_params.get( + "repetition_penalty", + self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"], + ) + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"] + ) + if (top_p := self.top_p) is None: + top_p = default_sampling_params.get( + "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"] + ) + if (top_k := self.top_k) is None: + top_k = default_sampling_params.get( + "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"] + ) + if (min_p := self.min_p) is None: + min_p = default_sampling_params.get( + "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"] + ) + + prompt_logprobs = self.prompt_logprobs + if prompt_logprobs is None and self.echo: + prompt_logprobs = self.logprobs + + echo_without_generation = self.echo and self.max_tokens == 0 + + response_format = self.response_format + if response_format is not None: + # If structured outputs wasn't already enabled, + # we must enable it for these features to work + if self.structured_outputs is None: + self.structured_outputs = StructuredOutputsParams() + + # Set structured output params for response format + if response_format.type == "json_object": + self.structured_outputs.json_object = True + elif response_format.type == "json_schema": + json_schema = response_format.json_schema + assert json_schema is not None + self.structured_outputs.json = json_schema.json_schema + elif response_format.type == "structural_tag": + structural_tag = response_format + assert structural_tag is not None and isinstance( + structural_tag, + ( + LegacyStructuralTagResponseFormat, + StructuralTagResponseFormat, + ), + ) + s_tag_obj = structural_tag.model_dump(by_alias=True) + self.structured_outputs.structural_tag = json.dumps(s_tag_obj) + + extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {} + if self.kv_transfer_params: + # Pass in kv_transfer_params via extra_args + 
extra_args["kv_transfer_params"] = self.kv_transfer_params + return SamplingParams.from_optional( + n=self.n, + presence_penalty=self.presence_penalty, + frequency_penalty=self.frequency_penalty, + repetition_penalty=repetition_penalty, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + seed=self.seed, + stop=self.stop, + stop_token_ids=self.stop_token_ids, + logprobs=self.logprobs, + ignore_eos=self.ignore_eos, + max_tokens=max_tokens if not echo_without_generation else 1, + min_tokens=self.min_tokens, + prompt_logprobs=prompt_logprobs, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=self.spaces_between_special_tokens, + include_stop_str_in_output=self.include_stop_str_in_output, + logits_processors=get_logits_processors( + self.logits_processors, logits_processor_pattern + ), + truncate_prompt_tokens=self.truncate_prompt_tokens, + output_kind=RequestOutputKind.DELTA + if self.stream + else RequestOutputKind.FINAL_ONLY, + structured_outputs=self.structured_outputs, + logit_bias=self.logit_bias, + allowed_token_ids=self.allowed_token_ids, + extra_args=extra_args or None, + skip_clone=True, # Created fresh per request, safe to skip clone + ) + + @model_validator(mode="before") + @classmethod + def check_structured_outputs_count(cls, data): + if data.get("structured_outputs", None) is None: + return data + + structured_outputs_kwargs = data["structured_outputs"] + count = sum( + structured_outputs_kwargs.get(k) is not None + for k in ("json", "regex", "choice") + ) + if count > 1: + raise VLLMValidationError( + "You can only use one kind of constraints for structured " + "outputs ('json', 'regex' or 'choice').", + parameter="structured_outputs", + ) + return data + + @model_validator(mode="before") + @classmethod + def check_logprobs(cls, data): + if (prompt_logprobs := data.get("prompt_logprobs")) is not None: + if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1): + raise 
VLLMValidationError( + "`prompt_logprobs` are not available when `stream=True`.", + parameter="prompt_logprobs", + ) + + if prompt_logprobs < 0 and prompt_logprobs != -1: + raise VLLMValidationError( + "`prompt_logprobs` must be a positive value or -1.", + parameter="prompt_logprobs", + value=prompt_logprobs, + ) + if (logprobs := data.get("logprobs")) is not None and logprobs < 0: + raise VLLMValidationError( + "`logprobs` must be a positive value.", + parameter="logprobs", + value=logprobs, + ) + + return data + + @model_validator(mode="before") + @classmethod + def validate_stream_options(cls, data): + if data.get("stream_options") and not data.get("stream"): + raise VLLMValidationError( + "Stream options can only be defined when `stream=True`.", + parameter="stream_options", + ) + + return data + + @model_validator(mode="before") + @classmethod + def validate_prompt_and_prompt_embeds(cls, data): + prompt = data.get("prompt") + prompt_embeds = data.get("prompt_embeds") + + prompt_is_empty = prompt is None or (isinstance(prompt, str) and prompt == "") + embeds_is_empty = prompt_embeds is None or ( + isinstance(prompt_embeds, list) and len(prompt_embeds) == 0 + ) + + if prompt_is_empty and embeds_is_empty: + raise ValueError( + "Either prompt or prompt_embeds must be provided and non-empty." + ) + + return data + + @model_validator(mode="before") + @classmethod + def check_cache_salt_support(cls, data): + if data.get("cache_salt") is not None and ( + not isinstance(data["cache_salt"], str) or not data["cache_salt"] + ): + raise ValueError( + "Parameter 'cache_salt' must be a non-empty string if provided." 
+ ) + return data + + +class CompletionLogProbs(OpenAIBaseModel): + text_offset: list[int] = Field(default_factory=list) + token_logprobs: list[float | None] = Field(default_factory=list) + tokens: list[str] = Field(default_factory=list) + top_logprobs: list[dict[str, float] | None] = Field(default_factory=list) + + +class CompletionResponseChoice(OpenAIBaseModel): + index: int + text: str + logprobs: CompletionLogProbs | None = None + finish_reason: str | None = None + stop_reason: int | str | None = Field( + default=None, + description=( + "The stop string or token id that caused the completion " + "to stop, None if the completion finished for some other reason " + "including encountering the EOS token" + ), + ) + token_ids: list[int] | None = None # For response + prompt_logprobs: list[dict[int, Logprob] | None] | None = None + prompt_token_ids: list[int] | None = None # For prompt + + +class CompletionResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: Literal["text_completion"] = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: list[CompletionResponseChoice] + service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None + system_fingerprint: str | None = None + usage: UsageInfo + + # vLLM-specific fields that are not in OpenAI spec + kv_transfer_params: dict[str, Any] | None = Field( + default=None, description="KVTransfer parameters." 
+ ) + + +class CompletionResponseStreamChoice(OpenAIBaseModel): + index: int + text: str + logprobs: CompletionLogProbs | None = None + finish_reason: str | None = None + stop_reason: int | str | None = Field( + default=None, + description=( + "The stop string or token id that caused the completion " + "to stop, None if the completion finished for some other reason " + "including encountering the EOS token" + ), + ) + # not part of the OpenAI spec but for tracing the tokens + # prompt tokens is put into choice to align with CompletionResponseChoice + prompt_token_ids: list[int] | None = None + token_ids: list[int] | None = None + + +class CompletionStreamResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: list[CompletionResponseStreamChoice] + usage: UsageInfo | None = Field(default=None) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/completion/serving.py similarity index 99% rename from vllm/entrypoints/openai/serving_completion.py rename to vllm/entrypoints/openai/completion/serving.py index 187ccb64e..2c573f77e 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/completion/serving.py @@ -12,27 +12,29 @@ from fastapi import Request from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.engine.protocol import ( +from vllm.entrypoints.openai.completion.protocol import ( CompletionLogProbs, CompletionRequest, CompletionResponse, CompletionResponseChoice, CompletionResponseStreamChoice, CompletionStreamResponse, +) +from vllm.entrypoints.openai.engine.protocol import ( ErrorResponse, PromptTokenUsageInfo, RequestResponseMetadata, UsageInfo, - VLLMValidationError, ) from vllm.entrypoints.openai.engine.serving import ( GenerationError, OpenAIServing, 
clamp_prompt_logprobs, ) -from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.utils import get_max_tokens, should_include_usage +from vllm.exceptions import VLLMValidationError from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt from vllm.logger import init_logger from vllm.logprobs import Logprob diff --git a/vllm/entrypoints/openai/engine/protocol.py b/vllm/entrypoints/openai/engine/protocol.py index 2fe76e8db..1f117a4ee 100644 --- a/vllm/entrypoints/openai/engine/protocol.py +++ b/vllm/entrypoints/openai/engine/protocol.py @@ -3,9 +3,8 @@ # Adapted from # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py -import json import time -from typing import Annotated, Any, ClassVar, Literal, TypeAlias +from typing import Any, ClassVar, Literal, TypeAlias import regex as re import torch @@ -17,14 +16,9 @@ from pydantic import ( ) from vllm.entrypoints.chat_utils import make_tool_call_id -from vllm.exceptions import VLLMValidationError from vllm.logger import init_logger -from vllm.logprobs import Logprob from vllm.sampling_params import ( - BeamSearchParams, - RequestOutputKind, SamplingParams, - StructuredOutputsParams, ) from vllm.utils import random_uuid from vllm.utils.import_utils import resolve_obj_by_qualname @@ -226,429 +220,6 @@ def get_logits_processors( return None -class CompletionRequest(OpenAIBaseModel): - # Ordered by official OpenAI API documentation - # https://platform.openai.com/docs/api-reference/completions/create - model: str | None = None - prompt: list[int] | list[list[int]] | str | list[str] | None = None - echo: bool | None = False - frequency_penalty: float | None = 0.0 - logit_bias: dict[str, float] | None = None - logprobs: int | None = None - max_tokens: int | None = 16 - n: int = 1 - 
presence_penalty: float | None = 0.0 - seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) - stop: str | list[str] | None = [] - stream: bool | None = False - stream_options: StreamOptions | None = None - suffix: str | None = None - temperature: float | None = None - top_p: float | None = None - user: str | None = None - - # --8<-- [start:completion-sampling-params] - use_beam_search: bool = False - top_k: int | None = None - min_p: float | None = None - repetition_penalty: float | None = None - length_penalty: float = 1.0 - stop_token_ids: list[int] | None = [] - include_stop_str_in_output: bool = False - ignore_eos: bool = False - min_tokens: int = 0 - skip_special_tokens: bool = True - spaces_between_special_tokens: bool = True - truncate_prompt_tokens: Annotated[int, Field(ge=-1, le=_LONG_INFO.max)] | None = ( - None - ) - allowed_token_ids: list[int] | None = None - prompt_logprobs: int | None = None - # --8<-- [end:completion-sampling-params] - - # --8<-- [start:completion-extra-params] - prompt_embeds: bytes | list[bytes] | None = None - add_special_tokens: bool = Field( - default=True, - description=( - "If true (the default), special tokens (e.g. BOS) will be added to " - "the prompt." - ), - ) - response_format: AnyResponseFormat | None = Field( - default=None, - description=( - "Similar to chat completion, this parameter specifies the format " - "of output. Only {'type': 'json_object'}, {'type': 'json_schema'}" - ", {'type': 'structural_tag'}, or {'type': 'text' } is supported." - ), - ) - structured_outputs: StructuredOutputsParams | None = Field( - default=None, - description="Additional kwargs for structured outputs", - ) - priority: int = Field( - default=0, - description=( - "The priority of the request (lower means earlier handling; " - "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling." 
- ), - ) - request_id: str = Field( - default_factory=random_uuid, - description=( - "The request_id related to this request. If the caller does " - "not set it, a random_uuid will be generated. This id is used " - "through out the inference process and return in response." - ), - ) - logits_processors: LogitsProcessors | None = Field( - default=None, - description=( - "A list of either qualified names of logits processors, or " - "constructor objects, to apply when sampling. A constructor is " - "a JSON object with a required 'qualname' field specifying the " - "qualified name of the processor class/factory, and optional " - "'args' and 'kwargs' fields containing positional and keyword " - "arguments. For example: {'qualname': " - "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': " - "{'param': 'value'}}." - ), - ) - - return_tokens_as_token_ids: bool | None = Field( - default=None, - description=( - "If specified with 'logprobs', tokens are represented " - " as strings of the form 'token_id:{token_id}' so that tokens " - "that are not JSON-encodable can be identified." - ), - ) - return_token_ids: bool | None = Field( - default=None, - description=( - "If specified, the result will include token IDs alongside the " - "generated text. In streaming mode, prompt_token_ids is included " - "only in the first chunk, and token_ids contains the delta tokens " - "for each chunk. This is useful for debugging or when you " - "need to map generated text back to input tokens." - ), - ) - - cache_salt: str | None = Field( - default=None, - description=( - "If specified, the prefix cache will be salted with the provided " - "string to prevent an attacker to guess prompts in multi-user " - "environments. The salt should be random, protected from " - "access by 3rd parties, and long enough to be " - "unpredictable (e.g., 43 characters base64-encoded, corresponding " - "to 256 bit)." 
- ), - ) - - kv_transfer_params: dict[str, Any] | None = Field( - default=None, - description="KVTransfer parameters used for disaggregated serving.", - ) - - vllm_xargs: dict[str, str | int | float] | None = Field( - default=None, - description=( - "Additional request parameters with string or " - "numeric values, used by custom extensions." - ), - ) - - # --8<-- [end:completion-extra-params] - - # Default sampling parameters for completion requests - _DEFAULT_SAMPLING_PARAMS: dict = { - "repetition_penalty": 1.0, - "temperature": 1.0, - "top_p": 1.0, - "top_k": 0, - "min_p": 0.0, - } - - def to_beam_search_params( - self, - max_tokens: int, - default_sampling_params: dict | None = None, - ) -> BeamSearchParams: - if default_sampling_params is None: - default_sampling_params = {} - n = self.n if self.n is not None else 1 - - if (temperature := self.temperature) is None: - temperature = default_sampling_params.get("temperature", 1.0) - - return BeamSearchParams( - beam_width=n, - max_tokens=max_tokens, - ignore_eos=self.ignore_eos, - temperature=temperature, - length_penalty=self.length_penalty, - include_stop_str_in_output=self.include_stop_str_in_output, - ) - - def to_sampling_params( - self, - max_tokens: int, - logits_processor_pattern: str | None, - default_sampling_params: dict | None = None, - ) -> SamplingParams: - if default_sampling_params is None: - default_sampling_params = {} - - # Default parameters - if (repetition_penalty := self.repetition_penalty) is None: - repetition_penalty = default_sampling_params.get( - "repetition_penalty", - self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"], - ) - if (temperature := self.temperature) is None: - temperature = default_sampling_params.get( - "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"] - ) - if (top_p := self.top_p) is None: - top_p = default_sampling_params.get( - "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"] - ) - if (top_k := self.top_k) is None: - top_k = default_sampling_params.get( - 
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"] - ) - if (min_p := self.min_p) is None: - min_p = default_sampling_params.get( - "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"] - ) - - prompt_logprobs = self.prompt_logprobs - if prompt_logprobs is None and self.echo: - prompt_logprobs = self.logprobs - - echo_without_generation = self.echo and self.max_tokens == 0 - - response_format = self.response_format - if response_format is not None: - # If structured outputs wasn't already enabled, - # we must enable it for these features to work - if self.structured_outputs is None: - self.structured_outputs = StructuredOutputsParams() - - # Set structured output params for response format - if response_format.type == "json_object": - self.structured_outputs.json_object = True - elif response_format.type == "json_schema": - json_schema = response_format.json_schema - assert json_schema is not None - self.structured_outputs.json = json_schema.json_schema - elif response_format.type == "structural_tag": - structural_tag = response_format - assert structural_tag is not None and isinstance( - structural_tag, - ( - LegacyStructuralTagResponseFormat, - StructuralTagResponseFormat, - ), - ) - s_tag_obj = structural_tag.model_dump(by_alias=True) - self.structured_outputs.structural_tag = json.dumps(s_tag_obj) - - extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {} - if self.kv_transfer_params: - # Pass in kv_transfer_params via extra_args - extra_args["kv_transfer_params"] = self.kv_transfer_params - return SamplingParams.from_optional( - n=self.n, - presence_penalty=self.presence_penalty, - frequency_penalty=self.frequency_penalty, - repetition_penalty=repetition_penalty, - temperature=temperature, - top_p=top_p, - top_k=top_k, - min_p=min_p, - seed=self.seed, - stop=self.stop, - stop_token_ids=self.stop_token_ids, - logprobs=self.logprobs, - ignore_eos=self.ignore_eos, - max_tokens=max_tokens if not echo_without_generation else 1, - min_tokens=self.min_tokens, 
- prompt_logprobs=prompt_logprobs, - skip_special_tokens=self.skip_special_tokens, - spaces_between_special_tokens=self.spaces_between_special_tokens, - include_stop_str_in_output=self.include_stop_str_in_output, - logits_processors=get_logits_processors( - self.logits_processors, logits_processor_pattern - ), - truncate_prompt_tokens=self.truncate_prompt_tokens, - output_kind=RequestOutputKind.DELTA - if self.stream - else RequestOutputKind.FINAL_ONLY, - structured_outputs=self.structured_outputs, - logit_bias=self.logit_bias, - allowed_token_ids=self.allowed_token_ids, - extra_args=extra_args or None, - skip_clone=True, # Created fresh per request, safe to skip clone - ) - - @model_validator(mode="before") - @classmethod - def check_structured_outputs_count(cls, data): - if data.get("structured_outputs", None) is None: - return data - - structured_outputs_kwargs = data["structured_outputs"] - count = sum( - structured_outputs_kwargs.get(k) is not None - for k in ("json", "regex", "choice") - ) - if count > 1: - raise VLLMValidationError( - "You can only use one kind of constraints for structured " - "outputs ('json', 'regex' or 'choice').", - parameter="structured_outputs", - ) - return data - - @model_validator(mode="before") - @classmethod - def check_logprobs(cls, data): - if (prompt_logprobs := data.get("prompt_logprobs")) is not None: - if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1): - raise VLLMValidationError( - "`prompt_logprobs` are not available when `stream=True`.", - parameter="prompt_logprobs", - ) - - if prompt_logprobs < 0 and prompt_logprobs != -1: - raise VLLMValidationError( - "`prompt_logprobs` must be a positive value or -1.", - parameter="prompt_logprobs", - value=prompt_logprobs, - ) - if (logprobs := data.get("logprobs")) is not None and logprobs < 0: - raise VLLMValidationError( - "`logprobs` must be a positive value.", - parameter="logprobs", - value=logprobs, - ) - - return data - - 
@model_validator(mode="before") - @classmethod - def validate_stream_options(cls, data): - if data.get("stream_options") and not data.get("stream"): - raise VLLMValidationError( - "Stream options can only be defined when `stream=True`.", - parameter="stream_options", - ) - - return data - - @model_validator(mode="before") - @classmethod - def validate_prompt_and_prompt_embeds(cls, data): - prompt = data.get("prompt") - prompt_embeds = data.get("prompt_embeds") - - prompt_is_empty = prompt is None or (isinstance(prompt, str) and prompt == "") - embeds_is_empty = prompt_embeds is None or ( - isinstance(prompt_embeds, list) and len(prompt_embeds) == 0 - ) - - if prompt_is_empty and embeds_is_empty: - raise ValueError( - "Either prompt or prompt_embeds must be provided and non-empty." - ) - - return data - - @model_validator(mode="before") - @classmethod - def check_cache_salt_support(cls, data): - if data.get("cache_salt") is not None and ( - not isinstance(data["cache_salt"], str) or not data["cache_salt"] - ): - raise ValueError( - "Parameter 'cache_salt' must be a non-empty string if provided." 
- ) - return data - - -class CompletionLogProbs(OpenAIBaseModel): - text_offset: list[int] = Field(default_factory=list) - token_logprobs: list[float | None] = Field(default_factory=list) - tokens: list[str] = Field(default_factory=list) - top_logprobs: list[dict[str, float] | None] = Field(default_factory=list) - - -class CompletionResponseChoice(OpenAIBaseModel): - index: int - text: str - logprobs: CompletionLogProbs | None = None - finish_reason: str | None = None - stop_reason: int | str | None = Field( - default=None, - description=( - "The stop string or token id that caused the completion " - "to stop, None if the completion finished for some other reason " - "including encountering the EOS token" - ), - ) - token_ids: list[int] | None = None # For response - prompt_logprobs: list[dict[int, Logprob] | None] | None = None - prompt_token_ids: list[int] | None = None # For prompt - - -class CompletionResponse(OpenAIBaseModel): - id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") - object: Literal["text_completion"] = "text_completion" - created: int = Field(default_factory=lambda: int(time.time())) - model: str - choices: list[CompletionResponseChoice] - service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None - system_fingerprint: str | None = None - usage: UsageInfo - - # vLLM-specific fields that are not in OpenAI spec - kv_transfer_params: dict[str, Any] | None = Field( - default=None, description="KVTransfer parameters." 
- ) - - -class CompletionResponseStreamChoice(OpenAIBaseModel): - index: int - text: str - logprobs: CompletionLogProbs | None = None - finish_reason: str | None = None - stop_reason: int | str | None = Field( - default=None, - description=( - "The stop string or token id that caused the completion " - "to stop, None if the completion finished for some other reason " - "including encountering the EOS token" - ), - ) - # not part of the OpenAI spec but for tracing the tokens - # prompt tokens is put into choice to align with CompletionResponseChoice - prompt_token_ids: list[int] | None = None - token_ids: list[int] | None = None - - -class CompletionStreamResponse(OpenAIBaseModel): - id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") - object: str = "text_completion" - created: int = Field(default_factory=lambda: int(time.time())) - model: str - choices: list[CompletionResponseStreamChoice] - usage: UsageInfo | None = Field(default=None) - - class FunctionCall(OpenAIBaseModel): name: str arguments: str diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py index f2a2a4bcd..fccb22dce 100644 --- a/vllm/entrypoints/openai/engine/serving.py +++ b/vllm/entrypoints/openai/engine/serving.py @@ -43,20 +43,21 @@ from vllm.entrypoints.openai.chat_completion.protocol import ( ChatCompletionRequest, ChatCompletionResponse, ) -from vllm.entrypoints.openai.engine.protocol import ( +from vllm.entrypoints.openai.completion.protocol import ( CompletionRequest, CompletionResponse, +) +from vllm.entrypoints.openai.engine.protocol import ( ErrorInfo, ErrorResponse, FunctionCall, FunctionDefinition, - VLLMValidationError, ) +from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.responses.protocol import ( ResponseInputOutputItem, ResponsesRequest, ) -from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.translations.protocol import ( 
TranscriptionRequest, TranscriptionResponse, @@ -95,6 +96,7 @@ from vllm.entrypoints.serve.tokenize.protocol import ( TokenizeResponse, ) from vllm.entrypoints.utils import _validate_truncation_size, sanitize_message +from vllm.exceptions import VLLMValidationError from vllm.inputs.data import PromptType, TokensPrompt from vllm.inputs.parse import ( PromptComponents, diff --git a/vllm/entrypoints/openai/models/__init__.py b/vllm/entrypoints/openai/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/vllm/entrypoints/openai/models/api_router.py b/vllm/entrypoints/openai/models/api_router.py new file mode 100644 index 000000000..2edda9c3e --- /dev/null +++ b/vllm/entrypoints/openai/models/api_router.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +from fastapi import APIRouter, FastAPI, Request +from fastapi.responses import JSONResponse + +from vllm.entrypoints.openai.models.serving import OpenAIServingModels +from vllm.logger import init_logger + +logger = init_logger(__name__) + +router = APIRouter() + + +def models(request: Request) -> OpenAIServingModels: + return request.app.state.openai_serving_models + + +@router.get("/v1/models") +async def show_available_models(raw_request: Request): + handler = models(raw_request) + + models_ = await handler.show_available_models() + return JSONResponse(content=models_.model_dump()) + + +def attach_router(app: FastAPI): + app.include_router(router) diff --git a/vllm/entrypoints/openai/models/protocol.py b/vllm/entrypoints/openai/models/protocol.py new file mode 100644 index 000000000..e7b96476c --- /dev/null +++ b/vllm/entrypoints/openai/models/protocol.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +from dataclasses import dataclass + + +@dataclass +class BaseModelPath: + name: str + model_path: str + + +@dataclass +class 
LoRAModulePath: + name: str + path: str + base_model_name: str | None = None diff --git a/vllm/entrypoints/openai/serving_models.py b/vllm/entrypoints/openai/models/serving.py similarity index 98% rename from vllm/entrypoints/openai/serving_models.py rename to vllm/entrypoints/openai/models/serving.py index 614a6fc32..2d8cf8f33 100644 --- a/vllm/entrypoints/openai/serving_models.py +++ b/vllm/entrypoints/openai/models/serving.py @@ -3,7 +3,6 @@ from asyncio import Lock from collections import defaultdict -from dataclasses import dataclass from http import HTTPStatus from vllm.engine.protocol import EngineClient @@ -14,6 +13,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ModelList, ModelPermission, ) +from vllm.entrypoints.openai.models.protocol import BaseModelPath, LoRAModulePath from vllm.entrypoints.serve.lora.protocol import ( LoadLoRAAdapterRequest, UnloadLoRAAdapterRequest, @@ -27,19 +27,6 @@ from vllm.utils.counter import AtomicCounter logger = init_logger(__name__) -@dataclass -class BaseModelPath: - name: str - model_path: str - - -@dataclass -class LoRAModulePath: - name: str - path: str - base_model_name: str | None = None - - class OpenAIServingModels: """Shared instance to hold data about the loaded base model(s) and adapters. 
diff --git a/vllm/entrypoints/openai/responses/serving.py b/vllm/entrypoints/openai/responses/serving.py index 8be40f741..2e5c0baa9 100644 --- a/vllm/entrypoints/openai/responses/serving.py +++ b/vllm/entrypoints/openai/responses/serving.py @@ -76,12 +76,12 @@ from vllm.entrypoints.openai.engine.protocol import ( DeltaMessage, ErrorResponse, RequestResponseMetadata, - VLLMValidationError, ) from vllm.entrypoints.openai.engine.serving import ( GenerationError, OpenAIServing, ) +from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.parser.harmony_utils import ( construct_harmony_previous_input_messages, get_developer_message, @@ -108,7 +108,6 @@ from vllm.entrypoints.openai.responses.protocol import ( ResponseUsage, StreamingResponsesResponse, ) -from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.responses_utils import ( construct_input_messages, construct_tool_dicts, @@ -116,6 +115,7 @@ from vllm.entrypoints.responses_utils import ( should_continue_final_message, ) from vllm.entrypoints.tool_server import ToolServer +from vllm.exceptions import VLLMValidationError from vllm.inputs.data import TokensPrompt from vllm.logger import init_logger from vllm.logprobs import Logprob as SampleLogprob diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index 5b72dc663..6f7da404a 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -28,7 +28,8 @@ from vllm.entrypoints.openai.engine.protocol import ( ErrorResponse, OpenAIBaseModel, ) -from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels +from vllm.entrypoints.openai.models.protocol import BaseModelPath +from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest, EmbeddingResponse from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding 
from vllm.entrypoints.pooling.score.protocol import ( diff --git a/vllm/entrypoints/openai/translations/api_router.py b/vllm/entrypoints/openai/translations/api_router.py index 86dc3dc66..dcc64a628 100644 --- a/vllm/entrypoints/openai/translations/api_router.py +++ b/vllm/entrypoints/openai/translations/api_router.py @@ -5,7 +5,7 @@ from http import HTTPStatus from typing import Annotated -from fastapi import APIRouter, FastAPI, Form, HTTPException, Request +from fastapi import APIRouter, FastAPI, Form, Request from fastapi.responses import JSONResponse, StreamingResponse from vllm.entrypoints.openai.engine.protocol import ErrorResponse @@ -63,10 +63,7 @@ async def create_transcriptions( try: generator = await handler.create_transcription(audio_data, request, raw_request) except Exception as e: - raise HTTPException( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) - ) from e - + return handler.create_error_response(e) if isinstance(generator, ErrorResponse): return JSONResponse( content=generator.model_dump(), status_code=generator.error.code @@ -103,9 +100,7 @@ async def create_translations( try: generator = await handler.create_translation(audio_data, request, raw_request) except Exception as e: - raise HTTPException( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) - ) from e + return handler.create_error_response(e) if isinstance(generator, ErrorResponse): return JSONResponse( diff --git a/vllm/entrypoints/openai/translations/serving.py b/vllm/entrypoints/openai/translations/serving.py index 6cbd4c265..646789bba 100644 --- a/vllm/entrypoints/openai/translations/serving.py +++ b/vllm/entrypoints/openai/translations/serving.py @@ -10,7 +10,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ErrorResponse, RequestResponseMetadata, ) -from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.openai.models.serving import OpenAIServingModels from 
vllm.entrypoints.openai.translations.protocol import ( TranscriptionRequest, TranscriptionResponse, diff --git a/vllm/entrypoints/openai/translations/speech_to_text.py b/vllm/entrypoints/openai/translations/speech_to_text.py index 08cfd8c29..48086c030 100644 --- a/vllm/entrypoints/openai/translations/speech_to_text.py +++ b/vllm/entrypoints/openai/translations/speech_to_text.py @@ -22,7 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import ( UsageInfo, ) from vllm.entrypoints.openai.engine.serving import OpenAIServing, SpeechToTextRequest -from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.translations.protocol import ( TranscriptionResponse, TranscriptionResponseStreamChoice, diff --git a/vllm/entrypoints/pooling/classify/serving.py b/vllm/entrypoints/pooling/classify/serving.py index 446366880..7da2210c6 100644 --- a/vllm/entrypoints/pooling/classify/serving.py +++ b/vllm/entrypoints/pooling/classify/serving.py @@ -23,7 +23,7 @@ from vllm.entrypoints.openai.engine.serving import ( OpenAIServing, ServeContext, ) -from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.pooling.classify.protocol import ( ClassificationChatRequest, ClassificationCompletionRequest, diff --git a/vllm/entrypoints/pooling/embed/serving.py b/vllm/entrypoints/pooling/embed/serving.py index 6e1381878..b48e3a016 100644 --- a/vllm/entrypoints/pooling/embed/serving.py +++ b/vllm/entrypoints/pooling/embed/serving.py @@ -21,7 +21,7 @@ from vllm.entrypoints.openai.engine.serving import ( OpenAIServing, ServeContext, ) -from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.pooling.embed.protocol import ( EmbeddingBytesResponse, EmbeddingChatRequest, diff 
--git a/vllm/entrypoints/pooling/pooling/serving.py b/vllm/entrypoints/pooling/pooling/serving.py index c27c9179e..1a2bfd770 100644 --- a/vllm/entrypoints/pooling/pooling/serving.py +++ b/vllm/entrypoints/pooling/pooling/serving.py @@ -19,7 +19,7 @@ from vllm.entrypoints.openai.engine.protocol import ( UsageInfo, ) from vllm.entrypoints.openai.engine.serving import OpenAIServing -from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.pooling.pooling.protocol import ( IOProcessorRequest, IOProcessorResponse, diff --git a/vllm/entrypoints/pooling/score/serving.py b/vllm/entrypoints/pooling/score/serving.py index e44f15e66..b798511e9 100644 --- a/vllm/entrypoints/pooling/score/serving.py +++ b/vllm/entrypoints/pooling/score/serving.py @@ -14,7 +14,7 @@ from vllm.entrypoints.openai.engine.protocol import ( UsageInfo, ) from vllm.entrypoints.openai.engine.serving import OpenAIServing -from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.pooling.score.protocol import ( RerankDocument, RerankRequest, diff --git a/vllm/entrypoints/sagemaker/routes.py b/vllm/entrypoints/sagemaker/routes.py index f2668baec..b00c6d82c 100644 --- a/vllm/entrypoints/sagemaker/routes.py +++ b/vllm/entrypoints/sagemaker/routes.py @@ -12,22 +12,26 @@ from fastapi.responses import JSONResponse, Response from vllm.entrypoints.openai.api_server import ( base, - chat, - completion, - create_completion, - validate_json_request, ) from vllm.entrypoints.openai.chat_completion.api_router import ( + chat, create_chat_completion, ) from vllm.entrypoints.openai.chat_completion.protocol import ( ChatCompletionRequest, ) -from vllm.entrypoints.openai.engine.protocol import ( +from vllm.entrypoints.openai.completion.api_router import ( + completion, + create_completion, +) +from 
vllm.entrypoints.openai.completion.protocol import ( CompletionRequest, +) +from vllm.entrypoints.openai.engine.protocol import ( ErrorResponse, ) from vllm.entrypoints.openai.engine.serving import OpenAIServing +from vllm.entrypoints.openai.utils import validate_json_request from vllm.entrypoints.pooling.classify.api_router import classify, create_classify from vllm.entrypoints.pooling.classify.protocol import ClassificationRequest from vllm.entrypoints.pooling.embed.api_router import create_embedding, embedding diff --git a/vllm/entrypoints/serve/disagg/api_router.py b/vllm/entrypoints/serve/disagg/api_router.py index 0b1d1e50a..08542ec5e 100644 --- a/vllm/entrypoints/serve/disagg/api_router.py +++ b/vllm/entrypoints/serve/disagg/api_router.py @@ -10,10 +10,10 @@ from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Respons from fastapi.responses import JSONResponse, StreamingResponse from vllm.engine.protocol import EngineClient -from vllm.entrypoints.openai.api_server import validate_json_request from vllm.entrypoints.openai.engine.protocol import ( ErrorResponse, ) +from vllm.entrypoints.openai.utils import validate_json_request from vllm.entrypoints.serve.disagg.protocol import ( GenerateRequest, GenerateResponse, diff --git a/vllm/entrypoints/serve/disagg/protocol.py b/vllm/entrypoints/serve/disagg/protocol.py index 68c39f906..659cf9e34 100644 --- a/vllm/entrypoints/serve/disagg/protocol.py +++ b/vllm/entrypoints/serve/disagg/protocol.py @@ -6,10 +6,10 @@ from pydantic import BaseModel, Field from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionLogProbs from vllm.entrypoints.openai.engine.protocol import ( - Logprob, SamplingParams, StreamOptions, ) +from vllm.logprobs import Logprob from vllm.utils import random_uuid diff --git a/vllm/entrypoints/serve/disagg/serving.py b/vllm/entrypoints/serve/disagg/serving.py index 8649bc668..5253040c5 100644 --- a/vllm/entrypoints/serve/disagg/serving.py +++ 
b/vllm/entrypoints/serve/disagg/serving.py @@ -23,7 +23,7 @@ from vllm.entrypoints.openai.engine.protocol import ( UsageInfo, ) from vllm.entrypoints.openai.engine.serving import OpenAIServing, clamp_prompt_logprobs -from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.serve.disagg.protocol import ( GenerateRequest, GenerateResponse, diff --git a/vllm/entrypoints/serve/elastic_ep/api_router.py b/vllm/entrypoints/serve/elastic_ep/api_router.py index 1a3b57d4c..00e38b611 100644 --- a/vllm/entrypoints/serve/elastic_ep/api_router.py +++ b/vllm/entrypoints/serve/elastic_ep/api_router.py @@ -9,10 +9,10 @@ from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request from fastapi.responses import JSONResponse from vllm.engine.protocol import EngineClient -from vllm.entrypoints.openai.api_server import validate_json_request from vllm.entrypoints.openai.engine.protocol import ( ErrorResponse, ) +from vllm.entrypoints.openai.utils import validate_json_request from vllm.entrypoints.serve.elastic_ep.middleware import ( get_scaling_elastic_ep, set_scaling_elastic_ep, diff --git a/vllm/entrypoints/serve/lora/api_router.py b/vllm/entrypoints/serve/lora/api_router.py index dd6f692ce..51bfc755f 100644 --- a/vllm/entrypoints/serve/lora/api_router.py +++ b/vllm/entrypoints/serve/lora/api_router.py @@ -7,11 +7,12 @@ from fastapi import APIRouter, Depends, FastAPI, Request from fastapi.responses import JSONResponse, Response from vllm import envs -from vllm.entrypoints.openai.api_server import models, validate_json_request from vllm.entrypoints.openai.engine.protocol import ( ErrorResponse, ) -from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.openai.models.api_router import models +from vllm.entrypoints.openai.models.serving import OpenAIServingModels +from vllm.entrypoints.openai.utils import validate_json_request from 
vllm.entrypoints.serve.lora.protocol import ( LoadLoRAAdapterRequest, UnloadLoRAAdapterRequest, diff --git a/vllm/entrypoints/serve/tokenize/api_router.py b/vllm/entrypoints/serve/tokenize/api_router.py index 7b0b466ab..66d34ef11 100644 --- a/vllm/entrypoints/serve/tokenize/api_router.py +++ b/vllm/entrypoints/serve/tokenize/api_router.py @@ -9,10 +9,10 @@ from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse from typing_extensions import assert_never -from vllm.entrypoints.openai.api_server import validate_json_request from vllm.entrypoints.openai.engine.protocol import ( ErrorResponse, ) +from vllm.entrypoints.openai.utils import validate_json_request from vllm.entrypoints.serve.tokenize.protocol import ( DetokenizeRequest, DetokenizeResponse, diff --git a/vllm/entrypoints/serve/tokenize/serving.py b/vllm/entrypoints/serve/tokenize/serving.py index a9f375163..b57c18bf5 100644 --- a/vllm/entrypoints/serve/tokenize/serving.py +++ b/vllm/entrypoints/serve/tokenize/serving.py @@ -13,7 +13,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ErrorResponse, ) from vllm.entrypoints.openai.engine.serving import OpenAIServing -from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.serve.tokenize.protocol import ( DetokenizeRequest, diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index 1134d49cf..9fb21484f 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -31,11 +31,13 @@ if TYPE_CHECKING: from vllm.entrypoints.openai.chat_completion.protocol import ( ChatCompletionRequest, ) - from vllm.entrypoints.openai.engine.protocol import ( + from vllm.entrypoints.openai.completion.protocol import ( CompletionRequest, + ) + from vllm.entrypoints.openai.engine.protocol import ( StreamOptions, ) - from 
vllm.entrypoints.openai.serving_models import LoRAModulePath + from vllm.entrypoints.openai.models.protocol import LoRAModulePath else: ChatCompletionRequest = object CompletionRequest = object @@ -281,7 +283,7 @@ def should_include_usage( def process_lora_modules( args_lora_modules: list[LoRAModulePath], default_mm_loras: dict[str, str] | None ) -> list[LoRAModulePath]: - from vllm.entrypoints.openai.serving_models import LoRAModulePath + from vllm.entrypoints.openai.models.serving import LoRAModulePath lora_modules = args_lora_modules if default_mm_loras: