diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 2170765e8..1ce706abc 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -8,6 +8,7 @@ import os import signal import socket import tempfile +import warnings from argparse import Namespace from collections.abc import AsyncIterator from contextlib import asynccontextmanager @@ -62,6 +63,8 @@ prometheus_multiproc_dir: tempfile.TemporaryDirectory # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) logger = init_logger("vllm.entrypoints.openai.api_server") +_FALLBACK_SUPPORTED_TASKS: tuple[SupportedTask, ...] = ("generate",) + @asynccontextmanager async def build_async_engine_client( @@ -152,7 +155,19 @@ async def build_async_engine_client_from_engine_args( async_llm.shutdown() -def build_app(args: Namespace, supported_tasks: tuple["SupportedTask", ...]) -> FastAPI: +def build_app( + args: Namespace, supported_tasks: tuple["SupportedTask", ...] | None = None +) -> FastAPI: + if supported_tasks is None: + warnings.warn( + "The 'supported_tasks' parameter was not provided to " + "build_app and will be required in a future version. " + "Defaulting to ('generate',).", + DeprecationWarning, + stacklevel=2, + ) + supported_tasks = _FALLBACK_SUPPORTED_TASKS + if args.disable_fastapi_docs: app = FastAPI( openapi_url=None, docs_url=None, redoc_url=None, lifespan=lifespan @@ -263,9 +278,18 @@ async def init_app_state( engine_client: EngineClient, state: State, args: Namespace, - supported_tasks: tuple["SupportedTask", ...], + supported_tasks: tuple["SupportedTask", ...] | None = None, ) -> None: vllm_config = engine_client.vllm_config + if supported_tasks is None: + warnings.warn( + "The 'supported_tasks' parameter was not provided to " + "init_app_state and will be required in a future version. " + "Please pass 'supported_tasks' explicitly.", + DeprecationWarning, + stacklevel=2, + ) + supported_tasks = _FALLBACK_SUPPORTED_TASKS if args.served_model_name is not None: served_model_names = args.served_model_name