[Frontend] Frontend will only attach supported tasks corresponding entrypoints. (#33139)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io> Signed-off-by: wang.yuqi <noooop@126.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
@@ -112,18 +112,14 @@ async def test_long_audio_request(mary_had_lamb, whisper_client):
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_endpoints(whisper_client):
|
||||
# text to text model
|
||||
res = await whisper_client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=[{"role": "system", "content": "You are a helpful assistant."}],
|
||||
)
|
||||
err = res.error
|
||||
assert err["code"] == 400
|
||||
assert err["message"] == "The model does not support Chat Completions API"
|
||||
with pytest.raises(openai.NotFoundError):
|
||||
await whisper_client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=[{"role": "system", "content": "You are a helpful assistant."}],
|
||||
)
|
||||
|
||||
res = await whisper_client.completions.create(model=MODEL_NAME, prompt="Hello")
|
||||
err = res.error
|
||||
assert err["code"] == 400
|
||||
assert err["message"] == "The model does not support Completions API"
|
||||
with pytest.raises(openai.NotFoundError):
|
||||
await whisper_client.completions.create(model=MODEL_NAME, prompt="Hello")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -9,6 +9,7 @@ import json
|
||||
import httpx
|
||||
import librosa
|
||||
import numpy as np
|
||||
import openai
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import soundfile as sf
|
||||
@@ -52,12 +53,11 @@ async def test_non_asr_model(foscolo, rocm_aiter_fa_attention):
|
||||
model_name, _get_server_args(rocm_aiter_fa_attention)
|
||||
) as remote_server:
|
||||
client = remote_server.get_async_client()
|
||||
res = await client.audio.translations.create(
|
||||
model=model_name, file=foscolo, temperature=0.0
|
||||
)
|
||||
err = res.error
|
||||
assert err["code"] == 400 and not res.text
|
||||
assert err["message"] == "The model does not support Translations API"
|
||||
|
||||
with pytest.raises(openai.NotFoundError):
|
||||
await client.audio.translations.create(
|
||||
model=model_name, file=foscolo, temperature=0.0
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -401,7 +401,7 @@ async def test_score(server: RemoteOpenAIServer, model_name: str):
|
||||
"documents": "pong",
|
||||
},
|
||||
)
|
||||
assert response.json()["error"]["type"] == "BadRequestError"
|
||||
assert response.json()["detail"] == "Not Found"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -416,7 +416,7 @@ async def test_rerank(server: RemoteOpenAIServer, model_name: str):
|
||||
"documents": ["pong"],
|
||||
},
|
||||
)
|
||||
assert response.json()["error"]["type"] == "BadRequestError"
|
||||
assert response.json()["detail"] == "Not Found"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -33,14 +33,10 @@ from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
import vllm.envs as envs
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.anthropic.serving import AnthropicServingMessages
|
||||
from vllm.entrypoints.chat_utils import load_chat_template
|
||||
from vllm.entrypoints.launcher import serve_http
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.mcp.tool_server import DemoToolServer, MCPToolServer, ToolServer
|
||||
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
|
||||
from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
@@ -50,12 +46,6 @@ from vllm.entrypoints.openai.models.protocol import BaseModelPath
|
||||
from vllm.entrypoints.openai.models.serving import (
|
||||
OpenAIServingModels,
|
||||
)
|
||||
from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses
|
||||
from vllm.entrypoints.openai.translations.serving import (
|
||||
OpenAIServingTranscription,
|
||||
OpenAIServingTranslation,
|
||||
)
|
||||
from vllm.entrypoints.serve.disagg.serving import ServingTokens
|
||||
from vllm.entrypoints.serve.elastic_ep.middleware import (
|
||||
ScalingMiddleware,
|
||||
)
|
||||
@@ -70,6 +60,7 @@ from vllm.entrypoints.utils import (
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.logger import init_logger
|
||||
from vllm.reasoning import ReasoningParserManager
|
||||
from vllm.tasks import POOLING_TASKS, SupportedTask
|
||||
from vllm.tool_parsers import ToolParserManager
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
@@ -513,7 +504,7 @@ def _log_non_streaming_response(response_body: list) -> None:
|
||||
logger.info("response_body={<binary_data>}")
|
||||
|
||||
|
||||
def build_app(args: Namespace) -> FastAPI:
|
||||
def build_app(args: Namespace, supported_tasks: tuple["SupportedTask", ...]) -> FastAPI:
|
||||
if args.disable_fastapi_docs:
|
||||
app = FastAPI(
|
||||
openapi_url=None, docs_url=None, redoc_url=None, lifespan=lifespan
|
||||
@@ -523,52 +514,44 @@ def build_app(args: Namespace) -> FastAPI:
|
||||
else:
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.state.args = args
|
||||
app.include_router(router)
|
||||
|
||||
from vllm.entrypoints.serve import register_vllm_serve_api_routers
|
||||
|
||||
register_vllm_serve_api_routers(app)
|
||||
from vllm.entrypoints.openai.chat_completion.api_router import (
|
||||
attach_router as register_chat_api_router,
|
||||
)
|
||||
|
||||
register_chat_api_router(app)
|
||||
|
||||
from vllm.entrypoints.openai.responses.api_router import (
|
||||
attach_router as register_responses_api_router,
|
||||
)
|
||||
|
||||
register_responses_api_router(app)
|
||||
from vllm.entrypoints.openai.translations.api_router import (
|
||||
attach_router as register_translations_api_router,
|
||||
)
|
||||
|
||||
register_translations_api_router(app)
|
||||
|
||||
from vllm.entrypoints.openai.completion.api_router import (
|
||||
attach_router as register_completion_api_router,
|
||||
)
|
||||
|
||||
register_completion_api_router(app)
|
||||
from vllm.entrypoints.anthropic.api_router import (
|
||||
attach_router as register_anthropic_api_router,
|
||||
)
|
||||
|
||||
register_anthropic_api_router(app)
|
||||
from vllm.entrypoints.openai.models.api_router import (
|
||||
attach_router as register_models_api_router,
|
||||
)
|
||||
|
||||
register_models_api_router(app)
|
||||
from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes
|
||||
|
||||
register_sagemaker_routes(router)
|
||||
app.include_router(router)
|
||||
from vllm.entrypoints.sagemaker.api_router import (
|
||||
attach_router as register_sagemaker_api_router,
|
||||
)
|
||||
|
||||
register_sagemaker_api_router(app, supported_tasks)
|
||||
|
||||
if "generate" in supported_tasks:
|
||||
from vllm.entrypoints.openai.generate.api_router import (
|
||||
register_generate_api_routers,
|
||||
)
|
||||
|
||||
register_generate_api_routers(app)
|
||||
|
||||
if "transcription" in supported_tasks:
|
||||
from vllm.entrypoints.openai.translations.api_router import (
|
||||
attach_router as register_translations_api_router,
|
||||
)
|
||||
|
||||
register_translations_api_router(app)
|
||||
|
||||
if any(task in POOLING_TASKS for task in supported_tasks):
|
||||
from vllm.entrypoints.pooling import register_pooling_api_routers
|
||||
|
||||
register_pooling_api_routers(app, supported_tasks)
|
||||
|
||||
app.root_path = args.root_path
|
||||
|
||||
from vllm.entrypoints.pooling import register_pooling_api_routers
|
||||
|
||||
register_pooling_api_routers(app)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=args.allowed_origins,
|
||||
@@ -673,6 +656,7 @@ async def init_app_state(
|
||||
engine_client: EngineClient,
|
||||
state: State,
|
||||
args: Namespace,
|
||||
supported_tasks: tuple["SupportedTask", ...],
|
||||
) -> None:
|
||||
vllm_config = engine_client.vllm_config
|
||||
|
||||
@@ -694,28 +678,9 @@ async def init_app_state(
|
||||
state.log_stats = not args.disable_log_stats
|
||||
state.vllm_config = vllm_config
|
||||
state.args = args
|
||||
supported_tasks = await engine_client.get_supported_tasks()
|
||||
logger.info("Supported tasks: %s", supported_tasks)
|
||||
|
||||
resolved_chat_template = load_chat_template(args.chat_template)
|
||||
|
||||
if args.tool_server == "demo":
|
||||
tool_server: ToolServer | None = DemoToolServer()
|
||||
assert isinstance(tool_server, DemoToolServer)
|
||||
await tool_server.init_and_validate()
|
||||
elif args.tool_server:
|
||||
tool_server = MCPToolServer()
|
||||
await tool_server.add_tool_server(args.tool_server)
|
||||
else:
|
||||
tool_server = None
|
||||
|
||||
# Merge default_mm_loras into the static lora_modules
|
||||
default_mm_loras = (
|
||||
vllm_config.lora_config.default_mm_loras
|
||||
if vllm_config.lora_config is not None
|
||||
else {}
|
||||
)
|
||||
|
||||
default_mm_loras = (
|
||||
vllm_config.lora_config.default_mm_loras
|
||||
if vllm_config.lora_config is not None
|
||||
@@ -729,66 +694,6 @@ async def init_app_state(
|
||||
lora_modules=lora_modules,
|
||||
)
|
||||
await state.openai_serving_models.init_static_loras()
|
||||
state.openai_serving_responses = (
|
||||
OpenAIServingResponses(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
enable_auto_tools=args.enable_auto_tool_choice,
|
||||
tool_parser=args.tool_call_parser,
|
||||
tool_server=tool_server,
|
||||
reasoning_parser=args.structured_outputs_config.reasoning_parser,
|
||||
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||
enable_force_include_usage=args.enable_force_include_usage,
|
||||
enable_log_outputs=args.enable_log_outputs,
|
||||
log_error_stack=args.log_error_stack,
|
||||
)
|
||||
if "generate" in supported_tasks
|
||||
else None
|
||||
)
|
||||
state.openai_serving_chat = (
|
||||
OpenAIServingChat(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
args.response_role,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
default_chat_template_kwargs=args.default_chat_template_kwargs,
|
||||
trust_request_chat_template=args.trust_request_chat_template,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
enable_auto_tools=args.enable_auto_tool_choice,
|
||||
exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
|
||||
tool_parser=args.tool_call_parser,
|
||||
reasoning_parser=args.structured_outputs_config.reasoning_parser,
|
||||
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||
enable_force_include_usage=args.enable_force_include_usage,
|
||||
enable_log_outputs=args.enable_log_outputs,
|
||||
enable_log_deltas=args.enable_log_deltas,
|
||||
log_error_stack=args.log_error_stack,
|
||||
)
|
||||
if "generate" in supported_tasks
|
||||
else None
|
||||
)
|
||||
# Warm up chat template processing to avoid first-request latency
|
||||
if state.openai_serving_chat is not None:
|
||||
await state.openai_serving_chat.warmup()
|
||||
state.openai_serving_completion = (
|
||||
OpenAIServingCompletion(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||
enable_force_include_usage=args.enable_force_include_usage,
|
||||
log_error_stack=args.log_error_stack,
|
||||
)
|
||||
if "generate" in supported_tasks
|
||||
else None
|
||||
)
|
||||
state.openai_serving_tokenization = OpenAIServingTokenization(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
@@ -798,64 +703,27 @@ async def init_app_state(
|
||||
trust_request_chat_template=args.trust_request_chat_template,
|
||||
log_error_stack=args.log_error_stack,
|
||||
)
|
||||
state.openai_serving_transcription = (
|
||||
OpenAIServingTranscription(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=args.log_error_stack,
|
||||
enable_force_include_usage=args.enable_force_include_usage,
|
||||
)
|
||||
if "transcription" in supported_tasks
|
||||
else None
|
||||
)
|
||||
state.openai_serving_translation = (
|
||||
OpenAIServingTranslation(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=args.log_error_stack,
|
||||
enable_force_include_usage=args.enable_force_include_usage,
|
||||
)
|
||||
if "transcription" in supported_tasks
|
||||
else None
|
||||
)
|
||||
state.anthropic_serving_messages = (
|
||||
AnthropicServingMessages(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
args.response_role,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
enable_auto_tools=args.enable_auto_tool_choice,
|
||||
tool_parser=args.tool_call_parser,
|
||||
reasoning_parser=args.structured_outputs_config.reasoning_parser,
|
||||
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||
enable_force_include_usage=args.enable_force_include_usage,
|
||||
)
|
||||
if "generate" in supported_tasks
|
||||
else None
|
||||
)
|
||||
state.serving_tokens = (
|
||||
ServingTokens(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
log_error_stack=args.log_error_stack,
|
||||
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||
enable_log_outputs=args.enable_log_outputs,
|
||||
force_no_detokenize=args.tokens_only,
|
||||
)
|
||||
if "generate" in supported_tasks
|
||||
else None
|
||||
)
|
||||
|
||||
from vllm.entrypoints.pooling import init_pooling_state
|
||||
if "generate" in supported_tasks:
|
||||
from vllm.entrypoints.openai.generate.api_router import init_generate_state
|
||||
|
||||
await init_pooling_state(engine_client, state, args)
|
||||
await init_generate_state(
|
||||
engine_client, state, args, request_logger, supported_tasks
|
||||
)
|
||||
|
||||
if "transcription" in supported_tasks:
|
||||
from vllm.entrypoints.openai.translations.api_router import (
|
||||
init_transcription_state,
|
||||
)
|
||||
|
||||
init_transcription_state(
|
||||
engine_client, state, args, request_logger, supported_tasks
|
||||
)
|
||||
|
||||
if any(task in POOLING_TASKS for task in supported_tasks):
|
||||
from vllm.entrypoints.pooling import init_pooling_state
|
||||
|
||||
init_pooling_state(engine_client, state, args, request_logger, supported_tasks)
|
||||
|
||||
state.enable_server_load_tracking = args.enable_server_load_tracking
|
||||
state.server_load_metrics = 0
|
||||
@@ -972,9 +840,11 @@ async def run_server_worker(
|
||||
args,
|
||||
client_config=client_config,
|
||||
) as engine_client:
|
||||
app = build_app(args)
|
||||
supported_tasks = await engine_client.get_supported_tasks()
|
||||
logger.info("Supported tasks: %s", supported_tasks)
|
||||
|
||||
await init_app_state(engine_client, app.state, args)
|
||||
app = build_app(args, supported_tasks)
|
||||
await init_app_state(engine_client, app.state, args, supported_tasks)
|
||||
|
||||
logger.info(
|
||||
"Starting vLLM API server %d on %s",
|
||||
|
||||
0
vllm/entrypoints/openai/generate/__init__.py
Normal file
0
vllm/entrypoints/openai/generate/__init__.py
Normal file
166
vllm/entrypoints/openai/generate/api_router.py
Normal file
166
vllm/entrypoints/openai/generate/api_router.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from argparse import Namespace
|
||||
|
||||
from starlette.datastructures import State
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.tasks import SupportedTask
|
||||
else:
|
||||
RequestLogger = object
|
||||
|
||||
|
||||
def register_generate_api_routers(app: FastAPI):
|
||||
from vllm.entrypoints.openai.chat_completion.api_router import (
|
||||
attach_router as register_chat_api_router,
|
||||
)
|
||||
|
||||
register_chat_api_router(app)
|
||||
|
||||
from vllm.entrypoints.openai.responses.api_router import (
|
||||
attach_router as register_responses_api_router,
|
||||
)
|
||||
|
||||
register_responses_api_router(app)
|
||||
|
||||
from vllm.entrypoints.openai.completion.api_router import (
|
||||
attach_router as register_completion_api_router,
|
||||
)
|
||||
|
||||
register_completion_api_router(app)
|
||||
|
||||
from vllm.entrypoints.anthropic.api_router import (
|
||||
attach_router as register_anthropic_api_router,
|
||||
)
|
||||
|
||||
register_anthropic_api_router(app)
|
||||
|
||||
|
||||
async def init_generate_state(
|
||||
engine_client: "EngineClient",
|
||||
state: "State",
|
||||
args: "Namespace",
|
||||
request_logger: RequestLogger | None,
|
||||
supported_tasks: tuple["SupportedTask", ...],
|
||||
):
|
||||
from vllm.entrypoints.anthropic.serving import AnthropicServingMessages
|
||||
from vllm.entrypoints.chat_utils import load_chat_template
|
||||
from vllm.entrypoints.mcp.tool_server import (
|
||||
DemoToolServer,
|
||||
MCPToolServer,
|
||||
ToolServer,
|
||||
)
|
||||
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion
|
||||
from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses
|
||||
from vllm.entrypoints.serve.disagg.serving import ServingTokens
|
||||
|
||||
if args.tool_server == "demo":
|
||||
tool_server: ToolServer | None = DemoToolServer()
|
||||
assert isinstance(tool_server, DemoToolServer)
|
||||
await tool_server.init_and_validate()
|
||||
elif args.tool_server:
|
||||
tool_server = MCPToolServer()
|
||||
await tool_server.add_tool_server(args.tool_server)
|
||||
else:
|
||||
tool_server = None
|
||||
resolved_chat_template = load_chat_template(args.chat_template)
|
||||
|
||||
state.openai_serving_responses = (
|
||||
OpenAIServingResponses(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
enable_auto_tools=args.enable_auto_tool_choice,
|
||||
tool_parser=args.tool_call_parser,
|
||||
tool_server=tool_server,
|
||||
reasoning_parser=args.structured_outputs_config.reasoning_parser,
|
||||
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||
enable_force_include_usage=args.enable_force_include_usage,
|
||||
enable_log_outputs=args.enable_log_outputs,
|
||||
log_error_stack=args.log_error_stack,
|
||||
)
|
||||
if "generate" in supported_tasks
|
||||
else None
|
||||
)
|
||||
state.openai_serving_chat = (
|
||||
OpenAIServingChat(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
args.response_role,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
default_chat_template_kwargs=args.default_chat_template_kwargs,
|
||||
trust_request_chat_template=args.trust_request_chat_template,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
enable_auto_tools=args.enable_auto_tool_choice,
|
||||
exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
|
||||
tool_parser=args.tool_call_parser,
|
||||
reasoning_parser=args.structured_outputs_config.reasoning_parser,
|
||||
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||
enable_force_include_usage=args.enable_force_include_usage,
|
||||
enable_log_outputs=args.enable_log_outputs,
|
||||
enable_log_deltas=args.enable_log_deltas,
|
||||
log_error_stack=args.log_error_stack,
|
||||
)
|
||||
if "generate" in supported_tasks
|
||||
else None
|
||||
)
|
||||
# Warm up chat template processing to avoid first-request latency
|
||||
if state.openai_serving_chat is not None:
|
||||
await state.openai_serving_chat.warmup()
|
||||
state.openai_serving_completion = (
|
||||
OpenAIServingCompletion(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||
enable_force_include_usage=args.enable_force_include_usage,
|
||||
log_error_stack=args.log_error_stack,
|
||||
)
|
||||
if "generate" in supported_tasks
|
||||
else None
|
||||
)
|
||||
state.anthropic_serving_messages = (
|
||||
AnthropicServingMessages(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
args.response_role,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
enable_auto_tools=args.enable_auto_tool_choice,
|
||||
tool_parser=args.tool_call_parser,
|
||||
reasoning_parser=args.structured_outputs_config.reasoning_parser,
|
||||
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||
enable_force_include_usage=args.enable_force_include_usage,
|
||||
)
|
||||
if "generate" in supported_tasks
|
||||
else None
|
||||
)
|
||||
state.serving_tokens = (
|
||||
ServingTokens(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
log_error_stack=args.log_error_stack,
|
||||
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||
enable_log_outputs=args.enable_log_outputs,
|
||||
force_no_detokenize=args.tokens_only,
|
||||
)
|
||||
if "generate" in supported_tasks
|
||||
else None
|
||||
)
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
|
||||
from http import HTTPStatus
|
||||
from typing import Annotated
|
||||
from typing import TYPE_CHECKING, Annotated
|
||||
|
||||
from fastapi import APIRouter, FastAPI, Form, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
@@ -25,6 +25,17 @@ from vllm.entrypoints.utils import (
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from argparse import Namespace
|
||||
|
||||
from starlette.datastructures import State
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.tasks import SupportedTask
|
||||
else:
|
||||
RequestLogger = object
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
@@ -115,3 +126,34 @@ async def create_translations(
|
||||
|
||||
def attach_router(app: FastAPI):
|
||||
app.include_router(router)
|
||||
|
||||
|
||||
def init_transcription_state(
|
||||
engine_client: "EngineClient",
|
||||
state: "State",
|
||||
args: "Namespace",
|
||||
request_logger: RequestLogger | None,
|
||||
supported_tasks: tuple["SupportedTask", ...],
|
||||
):
|
||||
state.openai_serving_transcription = (
|
||||
OpenAIServingTranscription(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=args.log_error_stack,
|
||||
enable_force_include_usage=args.enable_force_include_usage,
|
||||
)
|
||||
if "transcription" in supported_tasks
|
||||
else None
|
||||
)
|
||||
state.openai_serving_translation = (
|
||||
OpenAIServingTranslation(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=args.log_error_stack,
|
||||
enable_force_include_usage=args.enable_force_include_usage,
|
||||
)
|
||||
if "transcription" in supported_tasks
|
||||
else None
|
||||
)
|
||||
|
||||
@@ -11,40 +11,54 @@ if TYPE_CHECKING:
|
||||
from starlette.datastructures import State
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.tasks import SupportedTask
|
||||
else:
|
||||
RequestLogger = object
|
||||
SupportedTask = object
|
||||
|
||||
|
||||
def register_pooling_api_routers(app: FastAPI):
|
||||
from vllm.entrypoints.pooling.classify.api_router import router as classify_router
|
||||
from vllm.entrypoints.pooling.embed.api_router import router as embed_router
|
||||
def register_pooling_api_routers(
|
||||
app: FastAPI, supported_tasks: tuple["SupportedTask", ...]
|
||||
):
|
||||
from vllm.entrypoints.pooling.pooling.api_router import router as pooling_router
|
||||
from vllm.entrypoints.pooling.score.api_router import router as score_router
|
||||
|
||||
app.include_router(classify_router)
|
||||
app.include_router(embed_router)
|
||||
app.include_router(score_router)
|
||||
app.include_router(pooling_router)
|
||||
|
||||
if "classify" in supported_tasks:
|
||||
from vllm.entrypoints.pooling.classify.api_router import (
|
||||
router as classify_router,
|
||||
)
|
||||
|
||||
async def init_pooling_state(
|
||||
engine_client: "EngineClient", state: "State", args: "Namespace"
|
||||
app.include_router(classify_router)
|
||||
|
||||
if "embed" in supported_tasks:
|
||||
from vllm.entrypoints.pooling.embed.api_router import router as embed_router
|
||||
|
||||
app.include_router(embed_router)
|
||||
|
||||
if "score" in supported_tasks or "embed" in supported_tasks:
|
||||
from vllm.entrypoints.pooling.score.api_router import router as score_router
|
||||
|
||||
app.include_router(score_router)
|
||||
|
||||
|
||||
def init_pooling_state(
|
||||
engine_client: "EngineClient",
|
||||
state: "State",
|
||||
args: "Namespace",
|
||||
request_logger: RequestLogger | None,
|
||||
supported_tasks: tuple["SupportedTask", ...],
|
||||
):
|
||||
from vllm.entrypoints.chat_utils import load_chat_template
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.pooling.classify.serving import ServingClassification
|
||||
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
|
||||
from vllm.entrypoints.pooling.score.serving import ServingScores
|
||||
from vllm.tasks import POOLING_TASKS
|
||||
|
||||
supported_tasks = await engine_client.get_supported_tasks()
|
||||
|
||||
resolved_chat_template = load_chat_template(args.chat_template)
|
||||
|
||||
if args.enable_log_requests:
|
||||
request_logger = RequestLogger(max_log_len=args.max_log_len)
|
||||
else:
|
||||
request_logger = None
|
||||
|
||||
state.openai_serving_pooling = (
|
||||
(
|
||||
OpenAIServingPooling(
|
||||
|
||||
160
vllm/entrypoints/sagemaker/api_router.py
Normal file
160
vllm/entrypoints/sagemaker/api_router.py
Normal file
@@ -0,0 +1,160 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
|
||||
import model_hosting_container_standards.sagemaker as sagemaker_standards
|
||||
import pydantic
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
|
||||
from vllm.entrypoints.openai.api_server import base
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.engine.serving import OpenAIServing
|
||||
from vllm.entrypoints.openai.utils import validate_json_request
|
||||
from vllm.entrypoints.serve.instrumentator.health import health
|
||||
from vllm.tasks import POOLING_TASKS, SupportedTask
|
||||
|
||||
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
|
||||
# (requires typing_extensions >= 4.13)
|
||||
RequestType = Any
|
||||
GetHandlerFn = Callable[[Request], OpenAIServing | None]
|
||||
EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]
|
||||
|
||||
|
||||
def get_invocation_types(supported_tasks: tuple["SupportedTask", ...]):
|
||||
# NOTE: Items defined earlier take higher priority
|
||||
INVOCATION_TYPES: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = []
|
||||
|
||||
if "generate" in supported_tasks:
|
||||
from vllm.entrypoints.openai.chat_completion.api_router import (
|
||||
chat,
|
||||
create_chat_completion,
|
||||
)
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionRequest,
|
||||
)
|
||||
from vllm.entrypoints.openai.completion.api_router import (
|
||||
completion,
|
||||
create_completion,
|
||||
)
|
||||
from vllm.entrypoints.openai.completion.protocol import CompletionRequest
|
||||
|
||||
INVOCATION_TYPES += [
|
||||
(ChatCompletionRequest, (chat, create_chat_completion)),
|
||||
(CompletionRequest, (completion, create_completion)),
|
||||
]
|
||||
|
||||
if "embed" in supported_tasks:
|
||||
from vllm.entrypoints.pooling.embed.api_router import (
|
||||
create_embedding,
|
||||
embedding,
|
||||
)
|
||||
from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest
|
||||
|
||||
INVOCATION_TYPES += [
|
||||
(EmbeddingRequest, (embedding, create_embedding)),
|
||||
]
|
||||
|
||||
if "classify" in supported_tasks:
|
||||
from vllm.entrypoints.pooling.classify.api_router import (
|
||||
classify,
|
||||
create_classify,
|
||||
)
|
||||
from vllm.entrypoints.pooling.classify.protocol import ClassificationRequest
|
||||
|
||||
INVOCATION_TYPES += [
|
||||
(ClassificationRequest, (classify, create_classify)),
|
||||
]
|
||||
|
||||
if "score" in supported_tasks:
|
||||
from vllm.entrypoints.pooling.score.api_router import do_rerank, rerank
|
||||
from vllm.entrypoints.pooling.score.protocol import RerankRequest
|
||||
|
||||
INVOCATION_TYPES += [
|
||||
(RerankRequest, (rerank, do_rerank)),
|
||||
]
|
||||
|
||||
if "score" in supported_tasks or "embed" in supported_tasks:
|
||||
from vllm.entrypoints.pooling.score.api_router import create_score, score
|
||||
from vllm.entrypoints.pooling.score.protocol import ScoreRequest
|
||||
|
||||
INVOCATION_TYPES += [
|
||||
(ScoreRequest, (score, create_score)),
|
||||
]
|
||||
|
||||
if any(task in POOLING_TASKS for task in supported_tasks):
|
||||
from vllm.entrypoints.pooling.pooling.api_router import create_pooling, pooling
|
||||
from vllm.entrypoints.pooling.pooling.protocol import PoolingRequest
|
||||
|
||||
INVOCATION_TYPES += [
|
||||
(PoolingRequest, (pooling, create_pooling)),
|
||||
]
|
||||
|
||||
return INVOCATION_TYPES
|
||||
|
||||
|
||||
def attach_router(app: FastAPI, supported_tasks: tuple["SupportedTask", ...]):
|
||||
router = APIRouter()
|
||||
|
||||
# NOTE: Construct the TypeAdapters only once
|
||||
INVOCATION_TYPES = get_invocation_types(supported_tasks)
|
||||
INVOCATION_VALIDATORS = [
|
||||
(pydantic.TypeAdapter(request_type), (get_handler, endpoint))
|
||||
for request_type, (get_handler, endpoint) in INVOCATION_TYPES
|
||||
]
|
||||
|
||||
@router.post("/ping", response_class=Response)
|
||||
@router.get("/ping", response_class=Response)
|
||||
@sagemaker_standards.register_ping_handler
|
||||
async def ping(raw_request: Request) -> Response:
|
||||
"""Ping check. Endpoint required for SageMaker"""
|
||||
return await health(raw_request)
|
||||
|
||||
@router.post(
|
||||
"/invocations",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@sagemaker_standards.register_invocation_handler
|
||||
@sagemaker_standards.stateful_session_manager()
|
||||
@sagemaker_standards.inject_adapter_id(adapter_path="model")
|
||||
async def invocations(raw_request: Request):
|
||||
"""For SageMaker, routes requests based on the request type."""
|
||||
try:
|
||||
body = await raw_request.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_REQUEST.value,
|
||||
detail=f"JSON decode error: {e}",
|
||||
) from e
|
||||
|
||||
valid_endpoints = [
|
||||
(validator, endpoint)
|
||||
for validator, (get_handler, endpoint) in INVOCATION_VALIDATORS
|
||||
if get_handler(raw_request) is not None
|
||||
]
|
||||
|
||||
for request_validator, endpoint in valid_endpoints:
|
||||
try:
|
||||
request = request_validator.validate_python(body)
|
||||
except pydantic.ValidationError:
|
||||
continue
|
||||
|
||||
return await endpoint(request, raw_request)
|
||||
|
||||
type_names = [
|
||||
t.__name__ if isinstance(t := validator._type, type) else str(t)
|
||||
for validator, _ in valid_endpoints
|
||||
]
|
||||
msg = f"Cannot find suitable handler for request. Expected one of: {type_names}"
|
||||
res = base(raw_request).create_error_response(message=msg)
|
||||
return JSONResponse(content=res.model_dump(), status_code=res.error.code)
|
||||
|
||||
app.include_router(router)
|
||||
@@ -1,126 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
|
||||
import model_hosting_container_standards.sagemaker as sagemaker_standards
|
||||
import pydantic
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
|
||||
from vllm.entrypoints.openai.api_server import (
|
||||
base,
|
||||
)
|
||||
from vllm.entrypoints.openai.chat_completion.api_router import (
|
||||
chat,
|
||||
create_chat_completion,
|
||||
)
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionRequest,
|
||||
)
|
||||
from vllm.entrypoints.openai.completion.api_router import (
|
||||
completion,
|
||||
create_completion,
|
||||
)
|
||||
from vllm.entrypoints.openai.completion.protocol import (
|
||||
CompletionRequest,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
ErrorResponse,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.serving import OpenAIServing
|
||||
from vllm.entrypoints.openai.utils import validate_json_request
|
||||
from vllm.entrypoints.pooling.classify.api_router import classify, create_classify
|
||||
from vllm.entrypoints.pooling.classify.protocol import ClassificationRequest
|
||||
from vllm.entrypoints.pooling.embed.api_router import create_embedding, embedding
|
||||
from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest
|
||||
from vllm.entrypoints.pooling.pooling.api_router import create_pooling, pooling
|
||||
from vllm.entrypoints.pooling.pooling.protocol import PoolingRequest
|
||||
from vllm.entrypoints.pooling.score.api_router import (
|
||||
create_score,
|
||||
do_rerank,
|
||||
rerank,
|
||||
score,
|
||||
)
|
||||
from vllm.entrypoints.pooling.score.protocol import RerankRequest, ScoreRequest
|
||||
from vllm.entrypoints.serve.instrumentator.health import health
|
||||
|
||||
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
|
||||
# (requires typing_extensions >= 4.13)
|
||||
RequestType = Any
|
||||
GetHandlerFn = Callable[[Request], OpenAIServing | None]
|
||||
EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]
|
||||
|
||||
# NOTE: Items defined earlier take higher priority
|
||||
INVOCATION_TYPES: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = [
|
||||
(ChatCompletionRequest, (chat, create_chat_completion)),
|
||||
(CompletionRequest, (completion, create_completion)),
|
||||
(EmbeddingRequest, (embedding, create_embedding)),
|
||||
(ClassificationRequest, (classify, create_classify)),
|
||||
(ScoreRequest, (score, create_score)),
|
||||
(RerankRequest, (rerank, do_rerank)),
|
||||
(PoolingRequest, (pooling, create_pooling)),
|
||||
]
|
||||
|
||||
# NOTE: Construct the TypeAdapters only once
|
||||
INVOCATION_VALIDATORS = [
|
||||
(pydantic.TypeAdapter(request_type), (get_handler, endpoint))
|
||||
for request_type, (get_handler, endpoint) in INVOCATION_TYPES
|
||||
]
|
||||
|
||||
|
||||
def register_sagemaker_routes(router: APIRouter):
|
||||
@router.post("/ping", response_class=Response)
|
||||
@router.get("/ping", response_class=Response)
|
||||
@sagemaker_standards.register_ping_handler
|
||||
async def ping(raw_request: Request) -> Response:
|
||||
"""Ping check. Endpoint required for SageMaker"""
|
||||
return await health(raw_request)
|
||||
|
||||
@router.post(
|
||||
"/invocations",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@sagemaker_standards.register_invocation_handler
|
||||
@sagemaker_standards.stateful_session_manager()
|
||||
@sagemaker_standards.inject_adapter_id(adapter_path="model")
|
||||
async def invocations(raw_request: Request):
|
||||
"""For SageMaker, routes requests based on the request type."""
|
||||
try:
|
||||
body = await raw_request.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_REQUEST.value,
|
||||
detail=f"JSON decode error: {e}",
|
||||
) from e
|
||||
|
||||
valid_endpoints = [
|
||||
(validator, endpoint)
|
||||
for validator, (get_handler, endpoint) in INVOCATION_VALIDATORS
|
||||
if get_handler(raw_request) is not None
|
||||
]
|
||||
|
||||
for request_validator, endpoint in valid_endpoints:
|
||||
try:
|
||||
request = request_validator.validate_python(body)
|
||||
except pydantic.ValidationError:
|
||||
continue
|
||||
|
||||
return await endpoint(request, raw_request)
|
||||
|
||||
type_names = [
|
||||
t.__name__ if isinstance(t := validator._type, type) else str(t)
|
||||
for validator, _ in valid_endpoints
|
||||
]
|
||||
msg = f"Cannot find suitable handler for request. Expected one of: {type_names}"
|
||||
res = base(raw_request).create_error_response(message=msg)
|
||||
return JSONResponse(content=res.model_dump(), status_code=res.error.code)
|
||||
|
||||
return router
|
||||
Reference in New Issue
Block a user