diff --git a/tests/entrypoints/openai/test_transcription_validation_whisper.py b/tests/entrypoints/openai/test_transcription_validation_whisper.py index 67e358ffa..e645dcf78 100644 --- a/tests/entrypoints/openai/test_transcription_validation_whisper.py +++ b/tests/entrypoints/openai/test_transcription_validation_whisper.py @@ -112,18 +112,14 @@ async def test_long_audio_request(mary_had_lamb, whisper_client): @pytest.mark.asyncio async def test_completion_endpoints(whisper_client): # text to text model - res = await whisper_client.chat.completions.create( - model=MODEL_NAME, - messages=[{"role": "system", "content": "You are a helpful assistant."}], - ) - err = res.error - assert err["code"] == 400 - assert err["message"] == "The model does not support Chat Completions API" + with pytest.raises(openai.NotFoundError): + await whisper_client.chat.completions.create( + model=MODEL_NAME, + messages=[{"role": "system", "content": "You are a helpful assistant."}], + ) - res = await whisper_client.completions.create(model=MODEL_NAME, prompt="Hello") - err = res.error - assert err["code"] == 400 - assert err["message"] == "The model does not support Completions API" + with pytest.raises(openai.NotFoundError): + await whisper_client.completions.create(model=MODEL_NAME, prompt="Hello") @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_translation_validation.py b/tests/entrypoints/openai/test_translation_validation.py index 23502ea1b..9c33ca421 100644 --- a/tests/entrypoints/openai/test_translation_validation.py +++ b/tests/entrypoints/openai/test_translation_validation.py @@ -9,6 +9,7 @@ import json import httpx import librosa import numpy as np +import openai import pytest import pytest_asyncio import soundfile as sf @@ -52,12 +53,11 @@ async def test_non_asr_model(foscolo, rocm_aiter_fa_attention): model_name, _get_server_args(rocm_aiter_fa_attention) ) as remote_server: client = remote_server.get_async_client() - res = await client.audio.translations.create( - model=model_name, file=foscolo, temperature=0.0 - ) - err = res.error - assert err["code"] == 400 and not res.text - assert err["message"] == "The model does not support Translations API" + + with pytest.raises(openai.NotFoundError): + await client.audio.translations.create( + model=model_name, file=foscolo, temperature=0.0 + ) @pytest.mark.asyncio diff --git a/tests/entrypoints/pooling/classify/test_online.py b/tests/entrypoints/pooling/classify/test_online.py index 592c862d0..84b017393 100644 --- a/tests/entrypoints/pooling/classify/test_online.py +++ b/tests/entrypoints/pooling/classify/test_online.py @@ -401,7 +401,7 @@ async def test_score(server: RemoteOpenAIServer, model_name: str): "documents": "pong", }, ) - assert response.json()["error"]["type"] == "BadRequestError" + assert response.json()["detail"] == "Not Found" @pytest.mark.asyncio @@ -416,7 +416,7 @@ async def test_rerank(server: RemoteOpenAIServer, model_name: str): "documents": ["pong"], }, ) - assert response.json()["error"]["type"] == "BadRequestError" + assert response.json()["detail"] == "Not Found" @pytest.mark.asyncio diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index ff00f1d29..85aec7a88 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -33,14 +33,10 @@ from starlette.types import ASGIApp, Message, Receive, Scope, Send import vllm.envs as envs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.protocol import EngineClient -from vllm.entrypoints.anthropic.serving import AnthropicServingMessages from vllm.entrypoints.chat_utils import load_chat_template from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.mcp.tool_server import DemoToolServer, MCPToolServer, ToolServer -from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args -from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion from vllm.entrypoints.openai.engine.protocol import ( ErrorInfo, ErrorResponse, @@ -50,12 +46,6 @@ from vllm.entrypoints.openai.models.protocol import BaseModelPath from vllm.entrypoints.openai.models.serving import ( OpenAIServingModels, ) -from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses -from vllm.entrypoints.openai.translations.serving import ( - OpenAIServingTranscription, - OpenAIServingTranslation, -) -from vllm.entrypoints.serve.disagg.serving import ServingTokens from vllm.entrypoints.serve.elastic_ep.middleware import ( ScalingMiddleware, ) @@ -70,6 +60,7 @@ from vllm.entrypoints.utils import ( from vllm.exceptions import VLLMValidationError from vllm.logger import init_logger from vllm.reasoning import ReasoningParserManager +from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.tool_parsers import ToolParserManager from vllm.usage.usage_lib import UsageContext from vllm.utils.argparse_utils import FlexibleArgumentParser @@ -513,7 +504,7 @@ def _log_non_streaming_response(response_body: list) -> None: logger.info("response_body={}") -def build_app(args: Namespace) -> FastAPI: +def build_app(args: Namespace, supported_tasks: tuple["SupportedTask", ...]) -> FastAPI: if args.disable_fastapi_docs: app = FastAPI( openapi_url=None, docs_url=None, redoc_url=None, lifespan=lifespan @@ -523,52 +514,44 @@ def build_app(args: Namespace) -> FastAPI: else: app = FastAPI(lifespan=lifespan) app.state.args = args + app.include_router(router) + from vllm.entrypoints.serve import register_vllm_serve_api_routers register_vllm_serve_api_routers(app) - from vllm.entrypoints.openai.chat_completion.api_router import ( - attach_router as register_chat_api_router, - ) - register_chat_api_router(app) - - from vllm.entrypoints.openai.responses.api_router import ( - attach_router as register_responses_api_router, - ) - - register_responses_api_router(app) - from vllm.entrypoints.openai.translations.api_router import ( - attach_router as register_translations_api_router, - ) - - register_translations_api_router(app) - - from vllm.entrypoints.openai.completion.api_router import ( - attach_router as register_completion_api_router, - ) - - register_completion_api_router(app) - from vllm.entrypoints.anthropic.api_router import ( - attach_router as register_anthropic_api_router, - ) - - register_anthropic_api_router(app) from vllm.entrypoints.openai.models.api_router import ( attach_router as register_models_api_router, ) register_models_api_router(app) - from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes - register_sagemaker_routes(router) - app.include_router(router) + from vllm.entrypoints.sagemaker.api_router import ( + attach_router as register_sagemaker_api_router, + ) + + register_sagemaker_api_router(app, supported_tasks) + + if "generate" in supported_tasks: + from vllm.entrypoints.openai.generate.api_router import ( + register_generate_api_routers, + ) + + register_generate_api_routers(app) + + if "transcription" in supported_tasks: + from vllm.entrypoints.openai.translations.api_router import ( + attach_router as register_translations_api_router, + ) + + register_translations_api_router(app) + + if any(task in POOLING_TASKS for task in supported_tasks): + from vllm.entrypoints.pooling import register_pooling_api_routers + + register_pooling_api_routers(app, supported_tasks) app.root_path = args.root_path - - from vllm.entrypoints.pooling import register_pooling_api_routers - - register_pooling_api_routers(app) - app.add_middleware( CORSMiddleware, allow_origins=args.allowed_origins, @@ -673,6 +656,7 @@ async def init_app_state( engine_client: EngineClient, state: State, args: Namespace, + supported_tasks: tuple["SupportedTask", ...], ) -> None: vllm_config = engine_client.vllm_config @@ -694,28 +678,9 @@ async def init_app_state( state.log_stats = not args.disable_log_stats state.vllm_config = vllm_config state.args = args - supported_tasks = await engine_client.get_supported_tasks() - logger.info("Supported tasks: %s", supported_tasks) - resolved_chat_template = load_chat_template(args.chat_template) - if args.tool_server == "demo": - tool_server: ToolServer | None = DemoToolServer() - assert isinstance(tool_server, DemoToolServer) - await tool_server.init_and_validate() - elif args.tool_server: - tool_server = MCPToolServer() - await tool_server.add_tool_server(args.tool_server) - else: - tool_server = None - # Merge default_mm_loras into the static lora_modules - default_mm_loras = ( - vllm_config.lora_config.default_mm_loras - if vllm_config.lora_config is not None - else {} - ) - default_mm_loras = ( vllm_config.lora_config.default_mm_loras if vllm_config.lora_config is not None @@ -729,66 +694,6 @@ async def init_app_state( lora_modules=lora_modules, ) await state.openai_serving_models.init_static_loras() - state.openai_serving_responses = ( - OpenAIServingResponses( - engine_client, - state.openai_serving_models, - request_logger=request_logger, - chat_template=resolved_chat_template, - chat_template_content_format=args.chat_template_content_format, - return_tokens_as_token_ids=args.return_tokens_as_token_ids, - enable_auto_tools=args.enable_auto_tool_choice, - tool_parser=args.tool_call_parser, - tool_server=tool_server, - reasoning_parser=args.structured_outputs_config.reasoning_parser, - enable_prompt_tokens_details=args.enable_prompt_tokens_details, - enable_force_include_usage=args.enable_force_include_usage, - enable_log_outputs=args.enable_log_outputs, - log_error_stack=args.log_error_stack, - ) - if "generate" in supported_tasks - else None - ) - state.openai_serving_chat = ( - OpenAIServingChat( - engine_client, - state.openai_serving_models, - args.response_role, - request_logger=request_logger, - chat_template=resolved_chat_template, - chat_template_content_format=args.chat_template_content_format, - default_chat_template_kwargs=args.default_chat_template_kwargs, - trust_request_chat_template=args.trust_request_chat_template, - return_tokens_as_token_ids=args.return_tokens_as_token_ids, - enable_auto_tools=args.enable_auto_tool_choice, - exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none, - tool_parser=args.tool_call_parser, - reasoning_parser=args.structured_outputs_config.reasoning_parser, - enable_prompt_tokens_details=args.enable_prompt_tokens_details, - enable_force_include_usage=args.enable_force_include_usage, - enable_log_outputs=args.enable_log_outputs, - enable_log_deltas=args.enable_log_deltas, - log_error_stack=args.log_error_stack, - ) - if "generate" in supported_tasks - else None - ) - # Warm up chat template processing to avoid first-request latency - if state.openai_serving_chat is not None: - await state.openai_serving_chat.warmup() - state.openai_serving_completion = ( - OpenAIServingCompletion( - engine_client, - state.openai_serving_models, - request_logger=request_logger, - return_tokens_as_token_ids=args.return_tokens_as_token_ids, - enable_prompt_tokens_details=args.enable_prompt_tokens_details, - enable_force_include_usage=args.enable_force_include_usage, - log_error_stack=args.log_error_stack, - ) - if "generate" in supported_tasks - else None - ) state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, state.openai_serving_models, @@ -798,64 +703,27 @@ async def init_app_state( trust_request_chat_template=args.trust_request_chat_template, log_error_stack=args.log_error_stack, ) - state.openai_serving_transcription = ( - OpenAIServingTranscription( - engine_client, - state.openai_serving_models, - request_logger=request_logger, - log_error_stack=args.log_error_stack, - enable_force_include_usage=args.enable_force_include_usage, - ) - if "transcription" in supported_tasks - else None - ) - state.openai_serving_translation = ( - OpenAIServingTranslation( - engine_client, - state.openai_serving_models, - request_logger=request_logger, - log_error_stack=args.log_error_stack, - enable_force_include_usage=args.enable_force_include_usage, - ) - if "transcription" in supported_tasks - else None - ) - state.anthropic_serving_messages = ( - AnthropicServingMessages( - engine_client, - state.openai_serving_models, - args.response_role, - request_logger=request_logger, - chat_template=resolved_chat_template, - chat_template_content_format=args.chat_template_content_format, - return_tokens_as_token_ids=args.return_tokens_as_token_ids, - enable_auto_tools=args.enable_auto_tool_choice, - tool_parser=args.tool_call_parser, - reasoning_parser=args.structured_outputs_config.reasoning_parser, - enable_prompt_tokens_details=args.enable_prompt_tokens_details, - enable_force_include_usage=args.enable_force_include_usage, - ) - if "generate" in supported_tasks - else None - ) - state.serving_tokens = ( - ServingTokens( - engine_client, - state.openai_serving_models, - request_logger=request_logger, - return_tokens_as_token_ids=args.return_tokens_as_token_ids, - log_error_stack=args.log_error_stack, - enable_prompt_tokens_details=args.enable_prompt_tokens_details, - enable_log_outputs=args.enable_log_outputs, - force_no_detokenize=args.tokens_only, - ) - if "generate" in supported_tasks - else None - ) - from vllm.entrypoints.pooling import init_pooling_state + if "generate" in supported_tasks: + from vllm.entrypoints.openai.generate.api_router import init_generate_state - await init_pooling_state(engine_client, state, args) + await init_generate_state( + engine_client, state, args, request_logger, supported_tasks + ) + + if "transcription" in supported_tasks: + from vllm.entrypoints.openai.translations.api_router import ( + init_transcription_state, + ) + + init_transcription_state( + engine_client, state, args, request_logger, supported_tasks + ) + + if any(task in POOLING_TASKS for task in supported_tasks): + from vllm.entrypoints.pooling import init_pooling_state + + init_pooling_state(engine_client, state, args, request_logger, supported_tasks) state.enable_server_load_tracking = args.enable_server_load_tracking state.server_load_metrics = 0 @@ -972,9 +840,11 @@ async def run_server_worker( args, client_config=client_config, ) as engine_client: - app = build_app(args) + supported_tasks = await engine_client.get_supported_tasks() + logger.info("Supported tasks: %s", supported_tasks) - await init_app_state(engine_client, app.state, args) + app = build_app(args, supported_tasks) + await init_app_state(engine_client, app.state, args, supported_tasks) logger.info( "Starting vLLM API server %d on %s", diff --git a/vllm/entrypoints/openai/generate/__init__.py b/vllm/entrypoints/openai/generate/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/vllm/entrypoints/openai/generate/api_router.py b/vllm/entrypoints/openai/generate/api_router.py new file mode 100644 index 000000000..ac74c7582 --- /dev/null +++ b/vllm/entrypoints/openai/generate/api_router.py @@ -0,0 +1,166 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING + +from fastapi import FastAPI + +if TYPE_CHECKING: + from argparse import Namespace + + from starlette.datastructures import State + + from vllm.engine.protocol import EngineClient + from vllm.entrypoints.logger import RequestLogger + from vllm.tasks import SupportedTask +else: + RequestLogger = object + + +def register_generate_api_routers(app: FastAPI): + from vllm.entrypoints.openai.chat_completion.api_router import ( + attach_router as register_chat_api_router, + ) + + register_chat_api_router(app) + + from vllm.entrypoints.openai.responses.api_router import ( + attach_router as register_responses_api_router, + ) + + register_responses_api_router(app) + + from vllm.entrypoints.openai.completion.api_router import ( + attach_router as register_completion_api_router, + ) + + register_completion_api_router(app) + + from vllm.entrypoints.anthropic.api_router import ( + attach_router as register_anthropic_api_router, + ) + + register_anthropic_api_router(app) + + +async def init_generate_state( + engine_client: "EngineClient", + state: "State", + args: "Namespace", + request_logger: RequestLogger | None, + supported_tasks: tuple["SupportedTask", ...], +): + from vllm.entrypoints.anthropic.serving import AnthropicServingMessages + from vllm.entrypoints.chat_utils import load_chat_template + from vllm.entrypoints.mcp.tool_server import ( + DemoToolServer, + MCPToolServer, + ToolServer, + ) + from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat + from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion + from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses + from vllm.entrypoints.serve.disagg.serving import ServingTokens + + if args.tool_server == "demo": + tool_server: ToolServer | None = DemoToolServer() + assert isinstance(tool_server, DemoToolServer) + await tool_server.init_and_validate() + elif args.tool_server: + tool_server = MCPToolServer() + await tool_server.add_tool_server(args.tool_server) + else: + tool_server = None + resolved_chat_template = load_chat_template(args.chat_template) + + state.openai_serving_responses = ( + OpenAIServingResponses( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_auto_tools=args.enable_auto_tool_choice, + tool_parser=args.tool_call_parser, + tool_server=tool_server, + reasoning_parser=args.structured_outputs_config.reasoning_parser, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_force_include_usage=args.enable_force_include_usage, + enable_log_outputs=args.enable_log_outputs, + log_error_stack=args.log_error_stack, + ) + if "generate" in supported_tasks + else None + ) + state.openai_serving_chat = ( + OpenAIServingChat( + engine_client, + state.openai_serving_models, + args.response_role, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + default_chat_template_kwargs=args.default_chat_template_kwargs, + trust_request_chat_template=args.trust_request_chat_template, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_auto_tools=args.enable_auto_tool_choice, + exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none, + tool_parser=args.tool_call_parser, + reasoning_parser=args.structured_outputs_config.reasoning_parser, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_force_include_usage=args.enable_force_include_usage, + enable_log_outputs=args.enable_log_outputs, + enable_log_deltas=args.enable_log_deltas, + log_error_stack=args.log_error_stack, + ) + if "generate" in supported_tasks + else None + ) + # Warm up chat template processing to avoid first-request latency + if state.openai_serving_chat is not None: + await state.openai_serving_chat.warmup() + state.openai_serving_completion = ( + OpenAIServingCompletion( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_force_include_usage=args.enable_force_include_usage, + log_error_stack=args.log_error_stack, + ) + if "generate" in supported_tasks + else None + ) + state.anthropic_serving_messages = ( + AnthropicServingMessages( + engine_client, + state.openai_serving_models, + args.response_role, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_auto_tools=args.enable_auto_tool_choice, + tool_parser=args.tool_call_parser, + reasoning_parser=args.structured_outputs_config.reasoning_parser, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_force_include_usage=args.enable_force_include_usage, + ) + if "generate" in supported_tasks + else None + ) + state.serving_tokens = ( + ServingTokens( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + log_error_stack=args.log_error_stack, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_log_outputs=args.enable_log_outputs, + force_no_detokenize=args.tokens_only, + ) + if "generate" in supported_tasks + else None + ) diff --git a/vllm/entrypoints/openai/translations/api_router.py b/vllm/entrypoints/openai/translations/api_router.py index dcc64a628..7dd95161f 100644 --- a/vllm/entrypoints/openai/translations/api_router.py +++ b/vllm/entrypoints/openai/translations/api_router.py @@ -3,7 +3,7 @@ from http import HTTPStatus -from typing import Annotated +from typing import TYPE_CHECKING, Annotated from fastapi import APIRouter, FastAPI, Form, Request from fastapi.responses import JSONResponse, StreamingResponse @@ -25,6 +25,17 @@ from vllm.entrypoints.utils import ( ) from vllm.logger import init_logger +if TYPE_CHECKING: + from argparse import Namespace + + from starlette.datastructures import State + + from vllm.engine.protocol import EngineClient + from vllm.entrypoints.logger import RequestLogger + from vllm.tasks import SupportedTask +else: + RequestLogger = object + logger = init_logger(__name__) router = APIRouter() @@ -115,3 +126,34 @@ async def create_translations( def attach_router(app: FastAPI): app.include_router(router) + + +def init_transcription_state( + engine_client: "EngineClient", + state: "State", + args: "Namespace", + request_logger: RequestLogger | None, + supported_tasks: tuple["SupportedTask", ...], +): + state.openai_serving_transcription = ( + OpenAIServingTranscription( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + log_error_stack=args.log_error_stack, + enable_force_include_usage=args.enable_force_include_usage, + ) + if "transcription" in supported_tasks + else None + ) + state.openai_serving_translation = ( + OpenAIServingTranslation( + engine_client, + state.openai_serving_models, + request_logger=request_logger, + log_error_stack=args.log_error_stack, + enable_force_include_usage=args.enable_force_include_usage, + ) + if "transcription" in supported_tasks + else None + ) diff --git a/vllm/entrypoints/pooling/__init__.py b/vllm/entrypoints/pooling/__init__.py index 408542dfa..737f1efe8 100644 --- a/vllm/entrypoints/pooling/__init__.py +++ b/vllm/entrypoints/pooling/__init__.py @@ -11,40 +11,54 @@ if TYPE_CHECKING: from starlette.datastructures import State from vllm.engine.protocol import EngineClient + from vllm.entrypoints.logger import RequestLogger + from vllm.tasks import SupportedTask +else: + RequestLogger = object + SupportedTask = object -def register_pooling_api_routers(app: FastAPI): - from vllm.entrypoints.pooling.classify.api_router import router as classify_router - from vllm.entrypoints.pooling.embed.api_router import router as embed_router +def register_pooling_api_routers( + app: FastAPI, supported_tasks: tuple["SupportedTask", ...] +): from vllm.entrypoints.pooling.pooling.api_router import router as pooling_router - from vllm.entrypoints.pooling.score.api_router import router as score_router - app.include_router(classify_router) - app.include_router(embed_router) - app.include_router(score_router) app.include_router(pooling_router) + if "classify" in supported_tasks: + from vllm.entrypoints.pooling.classify.api_router import ( + router as classify_router, + ) -async def init_pooling_state( - engine_client: "EngineClient", state: "State", args: "Namespace" + app.include_router(classify_router) + + if "embed" in supported_tasks: + from vllm.entrypoints.pooling.embed.api_router import router as embed_router + + app.include_router(embed_router) + + if "score" in supported_tasks or "embed" in supported_tasks: + from vllm.entrypoints.pooling.score.api_router import router as score_router + + app.include_router(score_router) + + +def init_pooling_state( + engine_client: "EngineClient", + state: "State", + args: "Namespace", + request_logger: RequestLogger | None, + supported_tasks: tuple["SupportedTask", ...], ): from vllm.entrypoints.chat_utils import load_chat_template - from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.pooling.classify.serving import ServingClassification from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling from vllm.entrypoints.pooling.score.serving import ServingScores from vllm.tasks import POOLING_TASKS - supported_tasks = await engine_client.get_supported_tasks() - resolved_chat_template = load_chat_template(args.chat_template) - if args.enable_log_requests: - request_logger = RequestLogger(max_log_len=args.max_log_len) - else: - request_logger = None - state.openai_serving_pooling = ( ( OpenAIServingPooling( diff --git a/vllm/entrypoints/sagemaker/api_router.py b/vllm/entrypoints/sagemaker/api_router.py new file mode 100644 index 000000000..8e5dd3db2 --- /dev/null +++ b/vllm/entrypoints/sagemaker/api_router.py @@ -0,0 +1,160 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json +from collections.abc import Awaitable, Callable +from http import HTTPStatus +from typing import Any + +import model_hosting_container_standards.sagemaker as sagemaker_standards +import pydantic +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request +from fastapi.responses import JSONResponse, Response + +from vllm.entrypoints.openai.api_server import base +from vllm.entrypoints.openai.engine.protocol import ErrorResponse +from vllm.entrypoints.openai.engine.serving import OpenAIServing +from vllm.entrypoints.openai.utils import validate_json_request +from vllm.entrypoints.serve.instrumentator.health import health +from vllm.tasks import POOLING_TASKS, SupportedTask + +# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers +# (requires typing_extensions >= 4.13) +RequestType = Any +GetHandlerFn = Callable[[Request], OpenAIServing | None] +EndpointFn = Callable[[RequestType, Request], Awaitable[Any]] + + +def get_invocation_types(supported_tasks: tuple["SupportedTask", ...]): + # NOTE: Items defined earlier take higher priority + INVOCATION_TYPES: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = [] + + if "generate" in supported_tasks: + from vllm.entrypoints.openai.chat_completion.api_router import ( + chat, + create_chat_completion, + ) + from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ) + from vllm.entrypoints.openai.completion.api_router import ( + completion, + create_completion, + ) + from vllm.entrypoints.openai.completion.protocol import CompletionRequest + + INVOCATION_TYPES += [ + (ChatCompletionRequest, (chat, create_chat_completion)), + (CompletionRequest, (completion, create_completion)), + ] + + if "embed" in supported_tasks: + from vllm.entrypoints.pooling.embed.api_router import ( + create_embedding, + embedding, + ) + from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest + + INVOCATION_TYPES += [ + (EmbeddingRequest, (embedding, create_embedding)), + ] + + if "classify" in supported_tasks: + from vllm.entrypoints.pooling.classify.api_router import ( + classify, + create_classify, + ) + from vllm.entrypoints.pooling.classify.protocol import ClassificationRequest + + INVOCATION_TYPES += [ + (ClassificationRequest, (classify, create_classify)), + ] + + if "score" in supported_tasks: + from vllm.entrypoints.pooling.score.api_router import do_rerank, rerank + from vllm.entrypoints.pooling.score.protocol import RerankRequest + + INVOCATION_TYPES += [ + (RerankRequest, (rerank, do_rerank)), + ] + + if "score" in supported_tasks or "embed" in supported_tasks: + from vllm.entrypoints.pooling.score.api_router import create_score, score + from vllm.entrypoints.pooling.score.protocol import ScoreRequest + + INVOCATION_TYPES += [ + (ScoreRequest, (score, create_score)), + ] + + if any(task in POOLING_TASKS for task in supported_tasks): + from vllm.entrypoints.pooling.pooling.api_router import create_pooling, pooling + from vllm.entrypoints.pooling.pooling.protocol import PoolingRequest + + INVOCATION_TYPES += [ + (PoolingRequest, (pooling, create_pooling)), + ] + + return INVOCATION_TYPES + + +def attach_router(app: FastAPI, supported_tasks: tuple["SupportedTask", ...]): + router = APIRouter() + + # NOTE: Construct the TypeAdapters only once + INVOCATION_TYPES = get_invocation_types(supported_tasks) + INVOCATION_VALIDATORS = [ + (pydantic.TypeAdapter(request_type), (get_handler, endpoint)) + for request_type, (get_handler, endpoint) in INVOCATION_TYPES + ] + + @router.post("/ping", response_class=Response) + @router.get("/ping", response_class=Response) + @sagemaker_standards.register_ping_handler + async def ping(raw_request: Request) -> Response: + """Ping check. Endpoint required for SageMaker""" + return await health(raw_request) + + @router.post( + "/invocations", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, + ) + @sagemaker_standards.register_invocation_handler + @sagemaker_standards.stateful_session_manager() + @sagemaker_standards.inject_adapter_id(adapter_path="model") + async def invocations(raw_request: Request): + """For SageMaker, routes requests based on the request type.""" + try: + body = await raw_request.json() + except json.JSONDecodeError as e: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=f"JSON decode error: {e}", + ) from e + + valid_endpoints = [ + (validator, endpoint) + for validator, (get_handler, endpoint) in INVOCATION_VALIDATORS + if get_handler(raw_request) is not None + ] + + for request_validator, endpoint in valid_endpoints: + try: + request = request_validator.validate_python(body) + except pydantic.ValidationError: + continue + + return await endpoint(request, raw_request) + + type_names = [ + t.__name__ if isinstance(t := validator._type, type) else str(t) + for validator, _ in valid_endpoints + ] + msg = f"Cannot find suitable handler for request. Expected one of: {type_names}" + res = base(raw_request).create_error_response(message=msg) + return JSONResponse(content=res.model_dump(), status_code=res.error.code) + + app.include_router(router) diff --git a/vllm/entrypoints/sagemaker/routes.py b/vllm/entrypoints/sagemaker/routes.py deleted file mode 100644 index b00c6d82c..000000000 --- a/vllm/entrypoints/sagemaker/routes.py +++ /dev/null @@ -1,126 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import json -from collections.abc import Awaitable, Callable -from http import HTTPStatus -from typing import Any - -import model_hosting_container_standards.sagemaker as sagemaker_standards -import pydantic -from fastapi import APIRouter, Depends, HTTPException, Request -from fastapi.responses import JSONResponse, Response - -from vllm.entrypoints.openai.api_server import ( - base, -) -from vllm.entrypoints.openai.chat_completion.api_router import ( - chat, - create_chat_completion, -) -from vllm.entrypoints.openai.chat_completion.protocol import ( - ChatCompletionRequest, -) -from vllm.entrypoints.openai.completion.api_router import ( - completion, - create_completion, -) -from vllm.entrypoints.openai.completion.protocol import ( - CompletionRequest, -) -from vllm.entrypoints.openai.engine.protocol import ( - ErrorResponse, -) -from vllm.entrypoints.openai.engine.serving import OpenAIServing -from vllm.entrypoints.openai.utils import validate_json_request -from vllm.entrypoints.pooling.classify.api_router import classify, create_classify -from vllm.entrypoints.pooling.classify.protocol import ClassificationRequest -from vllm.entrypoints.pooling.embed.api_router import create_embedding, embedding -from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest -from vllm.entrypoints.pooling.pooling.api_router import create_pooling, pooling -from vllm.entrypoints.pooling.pooling.protocol import PoolingRequest -from vllm.entrypoints.pooling.score.api_router import ( - create_score, - do_rerank, - rerank, - score, -) -from vllm.entrypoints.pooling.score.protocol import RerankRequest, ScoreRequest -from vllm.entrypoints.serve.instrumentator.health import health - -# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers -# (requires typing_extensions >= 4.13) -RequestType = Any -GetHandlerFn = Callable[[Request], OpenAIServing | None] -EndpointFn = Callable[[RequestType, Request], Awaitable[Any]] - -# NOTE: Items defined earlier take higher priority -INVOCATION_TYPES: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = [ - (ChatCompletionRequest, (chat, create_chat_completion)), - (CompletionRequest, (completion, create_completion)), - (EmbeddingRequest, (embedding, create_embedding)), - (ClassificationRequest, (classify, create_classify)), - (ScoreRequest, (score, create_score)), - (RerankRequest, (rerank, do_rerank)), - (PoolingRequest, (pooling, create_pooling)), -] - -# NOTE: Construct the TypeAdapters only once -INVOCATION_VALIDATORS = [ - (pydantic.TypeAdapter(request_type), (get_handler, endpoint)) - for request_type, (get_handler, endpoint) in INVOCATION_TYPES -] - - -def register_sagemaker_routes(router: APIRouter): - @router.post("/ping", response_class=Response) - @router.get("/ping", response_class=Response) - @sagemaker_standards.register_ping_handler - async def ping(raw_request: Request) -> Response: - """Ping check. Endpoint required for SageMaker""" - return await health(raw_request) - - @router.post( - "/invocations", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, - HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: {"model": ErrorResponse}, - HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, - }, - ) - @sagemaker_standards.register_invocation_handler - @sagemaker_standards.stateful_session_manager() - @sagemaker_standards.inject_adapter_id(adapter_path="model") - async def invocations(raw_request: Request): - """For SageMaker, routes requests based on the request type.""" - try: - body = await raw_request.json() - except json.JSONDecodeError as e: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST.value, - detail=f"JSON decode error: {e}", - ) from e - - valid_endpoints = [ - (validator, endpoint) - for validator, (get_handler, endpoint) in INVOCATION_VALIDATORS - if get_handler(raw_request) is not None - ] - - for request_validator, endpoint in valid_endpoints: - try: - request = request_validator.validate_python(body) - except pydantic.ValidationError: - continue - - return await endpoint(request, raw_request) - - type_names = [ - t.__name__ if isinstance(t := validator._type, type) else str(t) - for validator, _ in valid_endpoints - ] - msg = f"Cannot find suitable handler for request. Expected one of: {type_names}" - res = base(raw_request).create_error_response(message=msg) - return JSONResponse(content=res.model_dump(), status_code=res.error.code) - - return router