[Refactor] [9/N] to simplify the vLLM openai translations serving ar chitecture (#32313)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
@@ -17,12 +17,12 @@ from argparse import Namespace
|
||||
from collections.abc import AsyncIterator, Awaitable
|
||||
from contextlib import asynccontextmanager
|
||||
from http import HTTPStatus
|
||||
from typing import Annotated, Any
|
||||
from typing import Any
|
||||
|
||||
import model_hosting_container_standards.sagemaker as sagemaker_standards
|
||||
import pydantic
|
||||
import uvloop
|
||||
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
@@ -49,10 +49,6 @@ from vllm.entrypoints.openai.engine.protocol import (
|
||||
CompletionResponse,
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponseVariant,
|
||||
TranslationRequest,
|
||||
TranslationResponseVariant,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.serving import OpenAIServing
|
||||
from vllm.entrypoints.openai.orca_metrics import metrics_header
|
||||
@@ -62,7 +58,7 @@ from vllm.entrypoints.openai.serving_models import (
|
||||
BaseModelPath,
|
||||
OpenAIServingModels,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_transcription import (
|
||||
from vllm.entrypoints.openai.translations.serving import (
|
||||
OpenAIServingTranscription,
|
||||
OpenAIServingTranslation,
|
||||
)
|
||||
@@ -239,10 +235,6 @@ def models(request: Request) -> OpenAIServingModels:
|
||||
return request.app.state.openai_serving_models
|
||||
|
||||
|
||||
def responses(request: Request) -> OpenAIServingResponses | None:
|
||||
return request.app.state.openai_serving_responses
|
||||
|
||||
|
||||
def messages(request: Request) -> AnthropicServingMessages:
|
||||
return request.app.state.anthropic_serving_messages
|
||||
|
||||
@@ -259,22 +251,10 @@ def tokenization(request: Request) -> OpenAIServingTokenization:
|
||||
return request.app.state.openai_serving_tokenization
|
||||
|
||||
|
||||
def transcription(request: Request) -> OpenAIServingTranscription:
|
||||
return request.app.state.openai_serving_transcription
|
||||
|
||||
|
||||
def translation(request: Request) -> OpenAIServingTranslation:
|
||||
return request.app.state.openai_serving_translation
|
||||
|
||||
|
||||
def engine_client(request: Request) -> EngineClient:
|
||||
return request.app.state.engine_client
|
||||
|
||||
|
||||
def generate_tokens(request: Request) -> ServingTokens | None:
|
||||
return request.app.state.serving_tokens
|
||||
|
||||
|
||||
@router.get("/load")
|
||||
async def get_server_load_metrics(request: Request):
|
||||
# This endpoint returns the current server load metrics.
|
||||
@@ -410,84 +390,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/audio/transcriptions",
|
||||
responses={
|
||||
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_transcriptions(
|
||||
raw_request: Request, request: Annotated[TranscriptionRequest, Form()]
|
||||
):
|
||||
handler = transcription(raw_request)
|
||||
if handler is None:
|
||||
return base(raw_request).create_error_response(
|
||||
message="The model does not support Transcriptions API"
|
||||
)
|
||||
|
||||
audio_data = await request.file.read()
|
||||
try:
|
||||
generator = await handler.create_transcription(audio_data, request, raw_request)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
|
||||
) from e
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
|
||||
elif isinstance(generator, TranscriptionResponseVariant):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/audio/translations",
|
||||
responses={
|
||||
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_translations(
|
||||
request: Annotated[TranslationRequest, Form()], raw_request: Request
|
||||
):
|
||||
handler = translation(raw_request)
|
||||
if handler is None:
|
||||
return base(raw_request).create_error_response(
|
||||
message="The model does not support Translations API"
|
||||
)
|
||||
|
||||
audio_data = await request.file.read()
|
||||
try:
|
||||
generator = await handler.create_translation(audio_data, request, raw_request)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
|
||||
) from e
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
|
||||
elif isinstance(generator, TranslationResponseVariant):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
def load_log_config(log_config_file: str | None) -> dict | None:
|
||||
if not log_config_file:
|
||||
return None
|
||||
@@ -741,6 +643,11 @@ def build_app(args: Namespace) -> FastAPI:
|
||||
)
|
||||
|
||||
register_responses_api_router(app)
|
||||
from vllm.entrypoints.openai.translations.api_router import (
|
||||
attach_router as register_translations_api_router,
|
||||
)
|
||||
|
||||
register_translations_api_router(app)
|
||||
from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes
|
||||
|
||||
register_sagemaker_routes(router)
|
||||
|
||||
Reference in New Issue
Block a user