[Frontend] Add support for transcriptions and translations to run_batch (#33934)

Signed-off-by: Pooya Davoodi <pooya.davoodi@parasail.io>
Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
Pooya Davoodi
2026-02-07 05:24:57 -08:00
committed by GitHub
parent 4df44c16ba
commit 2cb2340f7a
3 changed files with 555 additions and 101 deletions

View File

@@ -7,6 +7,7 @@ import tempfile
import pytest
from vllm.assets.audio import AudioAsset
from vllm.entrypoints.openai.run_batch import BatchRequestOutput
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
@@ -42,6 +43,27 @@ INPUT_RERANK_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/re
INPUT_REASONING_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "Qwen/Qwen3-0.6B", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Solve this math problem: 2+2=?"}]}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "Qwen/Qwen3-0.6B", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "What is the capital of France?"}]}}"""
# This is a valid but minimal audio file for testing
# The payload is base64 for a RIFF/WAVE header followed by an empty "data"
# chunk, i.e. a well-formed WAV that carries no samples.
MINIMAL_WAV_BASE64 = "UklGRiQAAABXQVZFZm10IBAAAAABAAEAQB8AAEAfAAABAAgAZGF0YQAAAAA="
# One-line batch input for /v1/audio/transcriptions whose file_url is a
# base64 data URL built from MINIMAL_WAV_BASE64.
# Literal JSON braces are doubled because the template goes through str.format.
INPUT_TRANSCRIPTION_BATCH = (
'{{"custom_id": "request-1", "method": "POST", "url": "/v1/audio/transcriptions", '
'"body": {{"model": "openai/whisper-large-v3", "file_url": "data:audio/wav;base64,{}", '
'"response_format": "json"}}}}\n'
).format(MINIMAL_WAV_BASE64)
# Same transcription request, but file_url is the URL exposed by the
# "mary_had_lamb" AudioAsset (presumably an http(s) URL, per the constant's
# name -- confirm against AudioAsset.url).
INPUT_TRANSCRIPTION_HTTP_BATCH = (
'{{"custom_id": "request-1", "method": "POST", "url": "/v1/audio/transcriptions", '
'"body": {{"model": "openai/whisper-large-v3", "file_url": "{}", '
'"response_format": "json"}}}}\n'
).format(AudioAsset("mary_had_lamb").url)
# One-line batch input for /v1/audio/translations: the same audio asset URL,
# plain-text response format, language "it" -> to_language "en", and
# temperature 0.0.
INPUT_TRANSLATION_BATCH = (
'{{"custom_id": "request-1", "method": "POST", "url": "/v1/audio/translations", '
'"body": {{"model": "openai/whisper-small", "file_url": "{}", '
'"response_format": "text", "language": "it", "to_language": "en", '
'"temperature": 0.0}}}}\n'
).format(AudioAsset("mary_had_lamb").url)
def test_empty_file():
with (
@@ -238,3 +260,121 @@ def test_reasoning_parser():
]
assert reasoning is not None
assert len(reasoning) > 0
def test_transcription():
    """Run ``vllm run-batch`` on a data-URL transcription request.

    The batch input is written to a temporary file, the CLI is executed as a
    subprocess, and every line of the output file must be a schema-valid
    ``BatchRequestOutput`` with no error and a body containing ``text`` and
    ``usage``.
    """
    with (
        tempfile.NamedTemporaryFile("w") as batch_in,
        tempfile.NamedTemporaryFile("r") as batch_out,
    ):
        batch_in.write(INPUT_TRANSCRIPTION_BATCH)
        # Flush so the subprocess sees the request when it opens the file.
        batch_in.flush()

        command = [
            "vllm",
            "run-batch",
            "-i",
            batch_in.name,
            "-o",
            batch_out.name,
            "--model",
            "openai/whisper-large-v3",
        ]
        proc = subprocess.Popen(command)
        proc.communicate()
        proc.wait()
        assert proc.returncode == 0, f"{proc=}"

        contents = batch_out.read()
        print(f"\n\ncontents: {contents}\n\n")
        for raw_record in contents.strip().split("\n"):
            # Validate against the output schema before inspecting fields.
            BatchRequestOutput.model_validate_json(raw_record)

            record = json.loads(raw_record)
            assert isinstance(record, dict)
            assert record["error"] is None

            body = record["response"]["body"]
            assert body is not None
            assert "text" in body
            assert "usage" in body
def test_transcription_http_url():
    """Run ``vllm run-batch`` on a transcription request that references a URL.

    Besides the schema and error checks, the transcribed text must contain
    the phrase "Mary had a little lamb".
    """
    with (
        tempfile.NamedTemporaryFile("w") as batch_in,
        tempfile.NamedTemporaryFile("r") as batch_out,
    ):
        batch_in.write(INPUT_TRANSCRIPTION_HTTP_BATCH)
        # Flush so the subprocess sees the request when it opens the file.
        batch_in.flush()

        command = [
            "vllm",
            "run-batch",
            "-i",
            batch_in.name,
            "-o",
            batch_out.name,
            "--model",
            "openai/whisper-large-v3",
        ]
        proc = subprocess.Popen(command)
        proc.communicate()
        proc.wait()
        assert proc.returncode == 0, f"{proc=}"

        raw_output = batch_out.read()
        for raw_record in raw_output.strip().split("\n"):
            # Validate against the output schema before inspecting fields.
            BatchRequestOutput.model_validate_json(raw_record)

            record = json.loads(raw_record)
            assert isinstance(record, dict)
            assert record["error"] is None

            body = record["response"]["body"]
            assert body is not None
            assert "text" in body
            assert "usage" in body

            transcript = body["text"]
            assert "Mary had a little lamb" in transcript
def test_translation():
    """Run ``vllm run-batch`` on a translation request.

    Each output line must be a schema-valid ``BatchRequestOutput`` without an
    error, and the returned text (compared case-insensitively) must mention
    "mary" or "lamb".
    """
    with (
        tempfile.NamedTemporaryFile("w") as batch_in,
        tempfile.NamedTemporaryFile("r") as batch_out,
    ):
        batch_in.write(INPUT_TRANSLATION_BATCH)
        # Flush so the subprocess sees the request when it opens the file.
        batch_in.flush()

        command = [
            "vllm",
            "run-batch",
            "-i",
            batch_in.name,
            "-o",
            batch_out.name,
            "--model",
            "openai/whisper-small",
        ]
        proc = subprocess.Popen(command)
        proc.communicate()
        proc.wait()
        assert proc.returncode == 0, f"{proc=}"

        raw_output = batch_out.read()
        for raw_record in raw_output.strip().split("\n"):
            # Validate against the output schema before inspecting fields.
            BatchRequestOutput.model_validate_json(raw_record)

            record = json.loads(raw_record)
            assert isinstance(record, dict)
            assert record["error"] is None

            body = record["response"]["body"]
            assert body is not None
            assert "text" in body

            normalized = str(body["text"]).strip().lower()
            assert any(word in normalized for word in ("mary", "lamb"))

View File

@@ -2,17 +2,20 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import base64
import tempfile
from argparse import Namespace
from collections.abc import Awaitable, Callable
from http import HTTPStatus
from io import StringIO
from io import BytesIO, StringIO
from typing import Any, TypeAlias
from urllib.parse import urlparse
import aiohttp
import torch
from fastapi import UploadFile
from prometheus_client import start_http_server
from pydantic import TypeAdapter, field_validator
from pydantic import Field, TypeAdapter, field_validator, model_validator
from pydantic_core.core_schema import ValidationInfo
from tqdm import tqdm
@@ -25,12 +28,28 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
)
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import (
ErrorInfo,
ErrorResponse,
OpenAIBaseModel,
)
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest, EmbeddingResponse
from vllm.entrypoints.openai.speech_to_text.protocol import (
TranscriptionRequest,
TranscriptionResponse,
TranscriptionResponseVerbose,
TranslationRequest,
TranslationResponse,
TranslationResponseVerbose,
)
from vllm.entrypoints.openai.speech_to_text.serving import (
OpenAIServingTranscription,
OpenAIServingTranslation,
)
from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingRequest,
EmbeddingResponse,
)
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
from vllm.entrypoints.pooling.score.protocol import (
RerankRequest,
@@ -41,6 +60,7 @@ from vllm.entrypoints.pooling.score.protocol import (
from vllm.entrypoints.pooling.score.serving import ServingScores
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.tasks import SupportedTask
from vllm.utils import random_uuid
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.version import __version__ as VLLM_VERSION
@@ -48,8 +68,73 @@ from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
class BatchTranscriptionRequest(TranscriptionRequest):
    """
    Transcription request as it appears in a batch input file.

    Inherits every option of TranscriptionRequest, but the audio is referenced
    through ``file_url`` instead of an uploaded ``file``, since a JSON batch
    line cannot carry an upload.
    """

    # Required: where to obtain the audio from.
    file_url: str = Field(
        ...,
        description=(
            "Either a URL of the audio or a data URL with base64 encoded audio data. "
        ),
    )

    # The inherited upload field becomes optional and is excluded from dumps;
    # batch requests never populate it themselves.
    file: UploadFile | None = Field(default=None, exclude=True)  # type: ignore[assignment]

    @model_validator(mode="before")
    @classmethod
    def validate_no_file(cls, data: Any):
        """Reject raw dict input that carries a ``file`` key."""
        if not isinstance(data, dict) or "file" not in data:
            return data
        raise ValueError(
            "The 'file' field is not supported in batch requests. "
            "Use 'file_url' instead."
        )
class BatchTranslationRequest(TranslationRequest):
    """
    Translation request as it appears in a batch input file.

    Inherits every option of TranslationRequest, but the audio is referenced
    through ``file_url`` instead of an uploaded ``file``, since a JSON batch
    line cannot carry an upload.
    """

    # Required: where to obtain the audio from.
    file_url: str = Field(
        ...,
        description=(
            "Either a URL of the audio or a data URL with base64 encoded audio data. "
        ),
    )

    # The inherited upload field becomes optional and is excluded from dumps;
    # batch requests never populate it themselves.
    file: UploadFile | None = Field(default=None, exclude=True)  # type: ignore[assignment]

    @model_validator(mode="before")
    @classmethod
    def validate_no_file(cls, data: Any):
        """Reject raw dict input that carries a ``file`` key."""
        if not isinstance(data, dict) or "file" not in data:
            return data
        raise ValueError(
            "The 'file' field is not supported in batch requests. "
            "Use 'file_url' instead."
        )
BatchRequestInputBody: TypeAlias = (
ChatCompletionRequest | EmbeddingRequest | ScoreRequest | RerankRequest
ChatCompletionRequest
| EmbeddingRequest
| ScoreRequest
| RerankRequest
| BatchTranscriptionRequest
| BatchTranslationRequest
)
@@ -88,6 +173,10 @@ class BatchRequestInput(OpenAIBaseModel):
return TypeAdapter(ScoreRequest).validate_python(value)
if url.endswith("/rerank"):
return RerankRequest.model_validate(value)
if url == "/v1/audio/transcriptions":
return BatchTranscriptionRequest.model_validate(value)
if url == "/v1/audio/translations":
return BatchTranslationRequest.model_validate(value)
return TypeAdapter(BatchRequestInputBody).validate_python(value)
@@ -104,6 +193,10 @@ class BatchResponseData(OpenAIBaseModel):
| EmbeddingResponse
| ScoreResponse
| RerankResponse
| TranscriptionResponse
| TranscriptionResponseVerbose
| TranslationResponse
| TranslationResponseVerbose
| None
) = None
@@ -361,6 +454,49 @@ async def write_file(
await write_local_file(path_or_url, batch_outputs)
async def download_bytes_from_url(url: str) -> bytes:
    """
    Download data from a URL or decode from a data URL.

    Args:
        url: Either an HTTP/HTTPS URL or a data URL (data:...;base64,...)

    Returns:
        Data as bytes

    Raises:
        ValueError: If the scheme is not http, https or data; if a data URL
            has no "," separator; if a data URL is not flagged as base64;
            or if the base64 payload is malformed (binascii.Error is a
            ValueError subclass).
        Exception: If the HTTP(S) response status is not 200.
    """
    parsed = urlparse(url)

    # Handle data URLs (base64 encoded)
    if parsed.scheme == "data":
        # Format: data:[<mediatype>][;base64],<base64_data>
        header, separator, payload = url.partition(",")
        if not separator:
            raise ValueError(f"Invalid data URL format: {url}")
        # The base64 flag is the last token before the comma. A plain
        # substring test would also match media types or parameters that
        # merely contain "base64" and decode percent-encoded data as base64.
        if not header.lower().endswith(";base64"):
            raise ValueError(f"Unsupported data URL encoding: {header}")
        return base64.b64decode(payload)

    # Handle HTTP/HTTPS URLs
    if parsed.scheme in ("http", "https"):
        async with (
            aiohttp.ClientSession() as session,
            session.get(url) as resp,
        ):
            if resp.status != 200:
                raise Exception(
                    f"Failed to download data from URL: {url}. Status: {resp.status}"
                )
            return await resp.read()

    raise ValueError(
        f"Unsupported URL scheme: {parsed.scheme}. "
        "Supported schemes: http, https, data"
    )
def make_error_request_output(
request: BatchRequestInput, error_msg: str
) -> BatchRequestOutput:
@@ -391,7 +527,16 @@ async def run_request(
if isinstance(
response,
(ChatCompletionResponse, EmbeddingResponse, ScoreResponse, RerankResponse),
(
ChatCompletionResponse,
EmbeddingResponse,
ScoreResponse,
RerankResponse,
TranscriptionResponse,
TranscriptionResponseVerbose,
TranslationResponse,
TranslationResponseVerbose,
),
):
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
@@ -420,38 +565,130 @@ async def run_request(
return batch_output
def validate_run_batch_args(args):
    """Check that the configured reasoning parser, if any, is registered.

    Args:
        args: Parsed arguments; ``args.structured_outputs_config.reasoning_parser``
            is read.

    Raises:
        KeyError: If a reasoning parser is set but is not among the names
            returned by ``ReasoningParserManager.list_registered()``.
    """
    valid_reasoning_parsers = ReasoningParserManager.list_registered()
    reasoning_parser = args.structured_outputs_config.reasoning_parser

    # Nothing to validate when no parser is set or the name is known.
    if not reasoning_parser or reasoning_parser in valid_reasoning_parsers:
        return

    raise KeyError(
        f"invalid reasoning parser: {reasoning_parser} "
        f"(chose from {{ {','.join(valid_reasoning_parsers)} }})"
    )
def handle_endpoint_request(
    request: BatchRequestInput,
    tracker: BatchProgressTracker,
    url_matcher: Callable[[str], bool],
    handler_getter: Callable[[], Callable | None],
    wrapper_fn: Callable[[Callable], Callable] | None = None,
) -> Awaitable[BatchRequestOutput] | None:
    """
    Dispatch one batch request to an endpoint if its URL matches.

    Args:
        request: The batch request input
        tracker: Progress tracker; ``submitted()`` is called once a handler
            has been selected
        url_matcher: Predicate deciding whether ``request.url`` belongs to
            this endpoint
        handler_getter: Returns the serving handler, or None when no handler
            is available for this endpoint
        wrapper_fn: Optional adapter applied to the handler before it runs
            (e.g., for transcriptions)

    Returns:
        None when the URL does not match; otherwise an awaitable producing
        the BatchRequestOutput (an error output if no handler is available).
    """
    if url_matcher(request.url):
        endpoint_handler = handler_getter()
        if endpoint_handler is None:
            return make_async_error_request_output(
                request,
                error_msg=f"Model does not support endpoint: {request.url}",
            )

        # Adapt the handler first, then count the request as submitted.
        dispatch_fn = (
            endpoint_handler if wrapper_fn is None else wrapper_fn(endpoint_handler)
        )
        tracker.submitted()
        return run_request(dispatch_fn, request, tracker)

    return None
async def run_batch(
def make_transcription_wrapper(is_translation: bool):
    """
    Build an adapter that lets a speech-to-text handler accept batch requests.

    The adapter fetches the audio referenced by ``file_url``, rebuilds the
    batch request as a TranscriptionRequest or TranslationRequest carrying an
    in-memory upload, and calls the handler with the audio bytes and that
    request.

    Args:
        is_translation: If True, requests are rebuilt as TranslationRequest;
            otherwise as TranscriptionRequest

    Returns:
        A function that takes a handler and returns the wrapped handler
    """
    # Only used in the error message; fixed for the lifetime of the wrapper.
    operation = "translation" if is_translation else "transcription"

    def wrapper(handler_fn: Callable):
        async def transcription_wrapper(
            batch_request_body: (BatchTranscriptionRequest | BatchTranslationRequest),
        ) -> (
            TranscriptionResponse
            | TranscriptionResponseVerbose
            | TranslationResponse
            | TranslationResponseVerbose
            | ErrorResponse
        ):
            try:
                audio_data = await download_bytes_from_url(batch_request_body.file_url)

                # Wrap the bytes in an upload object for the regular request.
                mock_file = UploadFile(
                    file=BytesIO(audio_data),
                    filename="audio.bin",
                )

                # Copy every field except file_url, then attach the upload.
                request_dict = batch_request_body.model_dump(exclude={"file_url"})
                request_dict["file"] = mock_file

                request_cls = (
                    TranslationRequest if is_translation else TranscriptionRequest
                )
                regular_request = request_cls.model_validate(request_dict)
                return await handler_fn(audio_data, regular_request)
            except Exception as e:
                # Any failure (download, validation, handler) is reported as
                # a bad-request error response instead of being raised.
                return ErrorResponse(
                    error=ErrorInfo(
                        message=f"Failed to process {operation}: {str(e)}",
                        type="BadRequestError",
                        code=HTTPStatus.BAD_REQUEST.value,
                    )
                )

        return transcription_wrapper

    return wrapper
def build_endpoint_registry(
engine_client: EngineClient,
args: Namespace,
) -> None:
if args.served_model_name is not None:
served_model_names = args.served_model_name
else:
served_model_names = [args.model]
base_model_paths: list[BaseModelPath],
request_logger: RequestLogger | None,
supported_tasks: tuple[SupportedTask, ...],
) -> dict[str, dict[str, Any]]:
"""
Build the endpoint registry with all serving objects and handler configurations.
if args.enable_log_requests:
request_logger = RequestLogger(max_log_len=args.max_log_len)
else:
request_logger = None
base_model_paths = [
BaseModelPath(name=name, model_path=args.model) for name in served_model_names
]
Args:
engine_client: The engine client
args: Command line arguments
base_model_paths: List of base model paths
request_logger: Optional request logger
supported_tasks: Tuple of supported tasks
Returns:
Dictionary mapping endpoint keys to their configurations
"""
model_config = engine_client.model_config
supported_tasks = await engine_client.get_supported_tasks()
logger.info("Supported tasks: %s", supported_tasks)
# Create the openai serving objects.
openai_serving_models = OpenAIServingModels(
@@ -507,6 +744,129 @@ async def run_batch(
else None
)
openai_serving_transcription = (
OpenAIServingTranscription(
engine_client,
openai_serving_models,
request_logger=request_logger,
enable_force_include_usage=args.enable_force_include_usage,
)
if "transcription" in supported_tasks
else None
)
openai_serving_translation = (
OpenAIServingTranslation(
engine_client,
openai_serving_models,
request_logger=request_logger,
enable_force_include_usage=args.enable_force_include_usage,
)
if "transcription" in supported_tasks
else None
)
# Registry of endpoint configurations
endpoint_registry: dict[str, dict[str, Any]] = {
"completions": {
"url_matcher": lambda url: url == "/v1/chat/completions",
"handler_getter": lambda: (
openai_serving_chat.create_chat_completion
if openai_serving_chat is not None
else None
),
"wrapper_fn": None,
},
"embeddings": {
"url_matcher": lambda url: url == "/v1/embeddings",
"handler_getter": lambda: (
openai_serving_embedding.create_embedding
if openai_serving_embedding is not None
else None
),
"wrapper_fn": None,
},
"score": {
"url_matcher": lambda url: url.endswith("/score"),
"handler_getter": lambda: (
openai_serving_scores.create_score
if openai_serving_scores is not None
else None
),
"wrapper_fn": None,
},
"rerank": {
"url_matcher": lambda url: url.endswith("/rerank"),
"handler_getter": lambda: (
openai_serving_scores.do_rerank
if openai_serving_scores is not None
else None
),
"wrapper_fn": None,
},
"transcriptions": {
"url_matcher": lambda url: url == "/v1/audio/transcriptions",
"handler_getter": lambda: (
openai_serving_transcription.create_transcription
if openai_serving_transcription is not None
else None
),
"wrapper_fn": make_transcription_wrapper(is_translation=False),
},
"translations": {
"url_matcher": lambda url: url == "/v1/audio/translations",
"handler_getter": lambda: (
openai_serving_translation.create_translation
if openai_serving_translation is not None
else None
),
"wrapper_fn": make_transcription_wrapper(is_translation=True),
},
}
return endpoint_registry
def validate_run_batch_args(args):
    """Check that the configured reasoning parser, if any, is registered.

    Args:
        args: Parsed arguments; ``args.structured_outputs_config.reasoning_parser``
            is read.

    Raises:
        KeyError: If a reasoning parser is set but is not among the names
            returned by ``ReasoningParserManager.list_registered()``.
    """
    valid_reasoning_parsers = ReasoningParserManager.list_registered()
    reasoning_parser = args.structured_outputs_config.reasoning_parser

    # Nothing to validate when no parser is set or the name is known.
    if not reasoning_parser or reasoning_parser in valid_reasoning_parsers:
        return

    raise KeyError(
        f"invalid reasoning parser: {reasoning_parser} "
        f"(chose from {{ {','.join(valid_reasoning_parsers)} }})"
    )
async def run_batch(
engine_client: EngineClient,
args: Namespace,
) -> None:
if args.served_model_name is not None:
served_model_names = args.served_model_name
else:
served_model_names = [args.model]
if args.enable_log_requests:
request_logger = RequestLogger(max_log_len=args.max_log_len)
else:
request_logger = None
base_model_paths = [
BaseModelPath(name=name, model_path=args.model) for name in served_model_names
]
supported_tasks = await engine_client.get_supported_tasks()
logger.info("Supported tasks: %s", supported_tasks)
endpoint_registry = build_endpoint_registry(
engine_client=engine_client,
args=args,
base_model_paths=base_model_paths,
request_logger=request_logger,
supported_tasks=supported_tasks,
)
tracker = BatchProgressTracker()
logger.info("Reading batch from %s...", args.input_file)
@@ -520,84 +880,32 @@ async def run_batch(
request = BatchRequestInput.model_validate_json(request_json)
# Determine the type of request and run it.
if request.url == "/v1/chat/completions":
chat_handler_fn = (
openai_serving_chat.create_chat_completion
if openai_serving_chat is not None
else None
)
if chat_handler_fn is None:
response_futures.append(
make_async_error_request_output(
request,
error_msg="The model does not support Chat Completions API",
)
)
continue
# Use the last segment of the URL as the endpoint key.
# More advanced URL matching is done in url_matcher of endpoint_registry.
endpoint_key = request.url.split("/")[-1]
response_futures.append(run_request(chat_handler_fn, request, tracker))
tracker.submitted()
elif request.url == "/v1/embeddings":
embed_handler_fn = (
openai_serving_embedding.create_embedding
if openai_serving_embedding is not None
else None
result = None
if endpoint_key in endpoint_registry:
endpoint_config = endpoint_registry[endpoint_key]
result = handle_endpoint_request(
request,
tracker,
url_matcher=endpoint_config["url_matcher"],
handler_getter=endpoint_config["handler_getter"],
wrapper_fn=endpoint_config["wrapper_fn"],
)
if embed_handler_fn is None:
response_futures.append(
make_async_error_request_output(
request,
error_msg="The model does not support Embeddings API",
)
)
continue
response_futures.append(run_request(embed_handler_fn, request, tracker))
tracker.submitted()
elif request.url.endswith("/score"):
score_handler_fn = (
openai_serving_scores.create_score
if openai_serving_scores is not None
else None
)
if score_handler_fn is None:
response_futures.append(
make_async_error_request_output(
request,
error_msg="The model does not support Scores API",
)
)
continue
response_futures.append(run_request(score_handler_fn, request, tracker))
tracker.submitted()
elif request.url.endswith("/rerank"):
rerank_handler_fn = (
openai_serving_scores.do_rerank
if openai_serving_scores is not None
else None
)
if rerank_handler_fn is None:
response_futures.append(
make_async_error_request_output(
request,
error_msg="The model does not support Rerank API",
)
)
continue
response_futures.append(run_request(rerank_handler_fn, request, tracker))
tracker.submitted()
if result is not None:
response_futures.append(result)
else:
response_futures.append(
make_async_error_request_output(
request,
error_msg=f"URL {request.url} was used. "
"Supported endpoints: /v1/chat/completions, /v1/embeddings,"
" /score, /rerank ."
"See vllm/entrypoints/openai/api_server.py for supported "
"score/rerank versions.",
" /v1/audio/transcriptions, /v1/audio/translations, /score, "
" /rerank. See vllm/entrypoints/openai/api_server.py "
"for supported score/rerank versions.",
)
)

View File

@@ -54,7 +54,10 @@ class OpenAIServingTranscription(OpenAISpeechToText):
)
async def create_transcription(
self, audio_data: bytes, request: TranscriptionRequest, raw_request: Request
self,
audio_data: bytes,
request: TranscriptionRequest,
raw_request: Request | None = None,
) -> (
TranscriptionResponse
| TranscriptionResponseVerbose
@@ -124,7 +127,10 @@ class OpenAIServingTranslation(OpenAISpeechToText):
)
async def create_translation(
self, audio_data: bytes, request: TranslationRequest, raw_request: Request
self,
audio_data: bytes,
request: TranslationRequest,
raw_request: Request | None = None,
) -> (
TranslationResponse
| TranslationResponseVerbose