[Frontend] Add /v1/audio/transcriptions OpenAI API endpoint (#12909)

This commit is contained in:
Nicolò Lucchesi
2025-02-13 16:23:45 +01:00
committed by GitHub
parent 37dfa60037
commit d84cef76eb
20 changed files with 910 additions and 19 deletions

View File

@@ -16,10 +16,10 @@ from argparse import Namespace
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union
from typing import Annotated, AsyncIterator, Dict, Optional, Set, Tuple, Union
import uvloop
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
@@ -61,6 +61,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ScoreRequest, ScoreResponse,
TokenizeRequest,
TokenizeResponse,
TranscriptionRequest,
TranscriptionResponse,
UnloadLoraAdapterRequest)
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
# yapf: enable
@@ -75,6 +77,8 @@ from vllm.entrypoints.openai.serving_rerank import JinaAIServingRerank
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.entrypoints.openai.serving_transcription import (
OpenAIServingTranscription)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.utils import with_cancellation
from vllm.logger import init_logger
@@ -327,6 +331,10 @@ def tokenization(request: Request) -> OpenAIServingTokenization:
return request.app.state.openai_serving_tokenization
def transcription(request: Request) -> OpenAIServingTranscription:
return request.app.state.openai_serving_transcription
def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client
@@ -520,6 +528,31 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
return await create_score(request, raw_request)
@router.post("/v1/audio/transcriptions")
@with_cancellation
async def create_transcriptions(request: Annotated[TranscriptionRequest,
Form()],
raw_request: Request):
handler = transcription(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Transcriptions API")
audio_data = await request.file.read()
generator = await handler.create_transcription(audio_data, request,
raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, TranscriptionResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/rerank", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def do_rerank(request: RerankRequest, raw_request: Request):
@@ -832,6 +865,12 @@ async def init_app_state(
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
)
state.openai_serving_transcription = OpenAIServingTranscription(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger,
) if model_config.runner_type == "transcription" else None
state.task = model_config.task