diff --git a/tests/entrypoints/openai/test_run_batch.py b/tests/entrypoints/openai/test_run_batch.py
index d80357a75..26b34a924 100644
--- a/tests/entrypoints/openai/test_run_batch.py
+++ b/tests/entrypoints/openai/test_run_batch.py
@@ -7,6 +7,7 @@ import tempfile
 
 import pytest
 
+from vllm.assets.audio import AudioAsset
 from vllm.entrypoints.openai.run_batch import BatchRequestOutput
 
 MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
@@ -42,6 +43,27 @@ INPUT_RERANK_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/re
 INPUT_REASONING_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "Qwen/Qwen3-0.6B", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Solve this math problem: 2+2=?"}]}}
 {"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "Qwen/Qwen3-0.6B", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "What is the capital of France?"}]}}"""
 
+# This is a valid but minimal audio file for testing
+MINIMAL_WAV_BASE64 = "UklGRiQAAABXQVZFZm10IBAAAAABAAEAQB8AAEAfAAABAAgAZGF0YQAAAAA="
+INPUT_TRANSCRIPTION_BATCH = (
+    '{{"custom_id": "request-1", "method": "POST", "url": "/v1/audio/transcriptions", '
+    '"body": {{"model": "openai/whisper-large-v3", "file_url": "data:audio/wav;base64,{}", '
+    '"response_format": "json"}}}}\n'
+).format(MINIMAL_WAV_BASE64)
+
+INPUT_TRANSCRIPTION_HTTP_BATCH = (
+    '{{"custom_id": "request-1", "method": "POST", "url": "/v1/audio/transcriptions", '
+    '"body": {{"model": "openai/whisper-large-v3", "file_url": "{}", '
+    '"response_format": "json"}}}}\n'
+).format(AudioAsset("mary_had_lamb").url)
+
+INPUT_TRANSLATION_BATCH = (
+    '{{"custom_id": "request-1", "method": "POST", "url": "/v1/audio/translations", '
+    '"body": {{"model": "openai/whisper-small", "file_url": "{}", '
+    '"response_format": "text", "language": "it", "to_language": "en", '
+    '"temperature": 0.0}}}}\n'
+).format(AudioAsset("mary_had_lamb").url)
+
 
 def test_empty_file():
     with (
@@ -238,3 +260,121 @@ def test_reasoning_parser():
     ]
     assert reasoning is not None
     assert len(reasoning) > 0
+
+
+def test_transcription():
+    with (
+        tempfile.NamedTemporaryFile("w") as input_file,
+        tempfile.NamedTemporaryFile("r") as output_file,
+    ):
+        input_file.write(INPUT_TRANSCRIPTION_BATCH)
+        input_file.flush()
+        proc = subprocess.Popen(
+            [
+                "vllm",
+                "run-batch",
+                "-i",
+                input_file.name,
+                "-o",
+                output_file.name,
+                "--model",
+                "openai/whisper-large-v3",
+            ],
+        )
+        proc.communicate()
+        proc.wait()
+        assert proc.returncode == 0, f"{proc=}"
+
+        contents = output_file.read()
+        print(f"\n\ncontents: {contents}\n\n")
+        for line in contents.strip().split("\n"):
+            BatchRequestOutput.model_validate_json(line)
+
+            line_dict = json.loads(line)
+            assert isinstance(line_dict, dict)
+            assert line_dict["error"] is None
+
+            response_body = line_dict["response"]["body"]
+            assert response_body is not None
+            assert "text" in response_body
+            assert "usage" in response_body
+
+
+def test_transcription_http_url():
+    with (
+        tempfile.NamedTemporaryFile("w") as input_file,
+        tempfile.NamedTemporaryFile("r") as output_file,
+    ):
+        input_file.write(INPUT_TRANSCRIPTION_HTTP_BATCH)
+        input_file.flush()
+        proc = subprocess.Popen(
+            [
+                "vllm",
+                "run-batch",
+                "-i",
+                input_file.name,
+                "-o",
+                output_file.name,
+                "--model",
+                "openai/whisper-large-v3",
+            ],
+        )
+        proc.communicate()
+        proc.wait()
+        assert proc.returncode == 0, f"{proc=}"
+
+        contents = output_file.read()
+        for line in contents.strip().split("\n"):
+            BatchRequestOutput.model_validate_json(line)
+
+            line_dict = json.loads(line)
+            assert isinstance(line_dict, dict)
+            assert line_dict["error"] is None
+
+            response_body = line_dict["response"]["body"]
+            assert response_body is not None
+            assert "text" in response_body
+            assert "usage" in response_body
+
+            transcription_text = response_body["text"]
+            assert "Mary had a little lamb" in transcription_text
+
+
+def test_translation():
+    with (
+        tempfile.NamedTemporaryFile("w") as input_file,
+        tempfile.NamedTemporaryFile("r") as output_file,
+    ):
+        input_file.write(INPUT_TRANSLATION_BATCH)
+        input_file.flush()
+        proc = subprocess.Popen(
+            [
+                "vllm",
+                "run-batch",
+                "-i",
+                input_file.name,
+                "-o",
+                output_file.name,
+                "--model",
+                "openai/whisper-small",
+            ],
+        )
+        proc.communicate()
+        proc.wait()
+        assert proc.returncode == 0, f"{proc=}"
+
+        contents = output_file.read()
+        for line in contents.strip().split("\n"):
+            BatchRequestOutput.model_validate_json(line)
+
+            line_dict = json.loads(line)
+            assert isinstance(line_dict, dict)
+            assert line_dict["error"] is None
+
+            response_body = line_dict["response"]["body"]
+            assert response_body is not None
+            assert "text" in response_body
+
+            translation_text = response_body["text"]
+            translation_text_lower = str(translation_text).strip().lower()
+            assert "mary" in translation_text_lower or "lamb" in translation_text_lower
diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py
index 0673bf37a..f3c145a0b 100644
--- a/vllm/entrypoints/openai/run_batch.py
+++ b/vllm/entrypoints/openai/run_batch.py
@@ -2,17 +2,20 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import asyncio
+import base64
 import tempfile
 from argparse import Namespace
 from collections.abc import Awaitable, Callable
 from http import HTTPStatus
-from io import StringIO
+from io import BytesIO, StringIO
 from typing import Any, TypeAlias
+from urllib.parse import urlparse
 
 import aiohttp
 import torch
+from fastapi import UploadFile
 from prometheus_client import start_http_server
-from pydantic import TypeAdapter, field_validator
+from pydantic import Field, TypeAdapter, field_validator, model_validator
 from pydantic_core.core_schema import ValidationInfo
 from tqdm import tqdm
 
@@ -25,12 +28,28 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
 )
 from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
 from vllm.entrypoints.openai.engine.protocol import (
+    ErrorInfo,
     ErrorResponse,
     OpenAIBaseModel,
 )
 from vllm.entrypoints.openai.models.protocol import BaseModelPath
 from vllm.entrypoints.openai.models.serving import OpenAIServingModels
-from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest, EmbeddingResponse
+from vllm.entrypoints.openai.speech_to_text.protocol import (
+    TranscriptionRequest,
+    TranscriptionResponse,
+    TranscriptionResponseVerbose,
+    TranslationRequest,
+    TranslationResponse,
+    TranslationResponseVerbose,
+)
+from vllm.entrypoints.openai.speech_to_text.serving import (
+    OpenAIServingTranscription,
+    OpenAIServingTranslation,
+)
+from vllm.entrypoints.pooling.embed.protocol import (
+    EmbeddingRequest,
+    EmbeddingResponse,
+)
 from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
 from vllm.entrypoints.pooling.score.protocol import (
     RerankRequest,
@@ -41,6 +60,7 @@ from vllm.entrypoints.pooling.score.protocol import (
 from vllm.entrypoints.pooling.score.serving import ServingScores
 from vllm.logger import init_logger
 from vllm.reasoning import ReasoningParserManager
+from vllm.tasks import SupportedTask
 from vllm.utils import random_uuid
 from vllm.utils.argparse_utils import FlexibleArgumentParser
 from vllm.version import __version__ as VLLM_VERSION
@@ -48,8 +68,73 @@ from vllm.version import __version__ as VLLM_VERSION
 
 logger = init_logger(__name__)
 
 
+class BatchTranscriptionRequest(TranscriptionRequest):
+    """
+    Batch transcription request that uses file_url instead of file.
+
+    This class extends TranscriptionRequest but replaces the file field
+    with file_url, so that batch requests written as JSON can reference
+    audio by URL instead of uploading a file.
+    """
+
+    file_url: str = Field(
+        ...,
+        description=(
+            "Either a URL of the audio or a data URL with base64-encoded audio data."
+        ),
+    )
+
+    # Override file to be optional and unused for batch processing
+    file: UploadFile | None = Field(default=None, exclude=True)  # type: ignore[assignment]
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_no_file(cls, data: Any):
+        """Ensure the file field is not provided in batch requests."""
+        if isinstance(data, dict) and "file" in data:
+            raise ValueError(
+                "The 'file' field is not supported in batch requests. "
+                "Use 'file_url' instead."
+            )
+        return data
+
+
+class BatchTranslationRequest(TranslationRequest):
+    """
+    Batch translation request that uses file_url instead of file.
+
+    This class extends TranslationRequest but replaces the file field
+    with file_url, so that batch requests written as JSON can reference
+    audio by URL instead of uploading a file.
+    """
+
+    file_url: str = Field(
+        ...,
+        description=(
+            "Either a URL of the audio or a data URL with base64-encoded audio data."
+        ),
+    )
+
+    # Override file to be optional and unused for batch processing
+    file: UploadFile | None = Field(default=None, exclude=True)  # type: ignore[assignment]
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_no_file(cls, data: Any):
+        """Ensure the file field is not provided in batch requests."""
+        if isinstance(data, dict) and "file" in data:
+            raise ValueError(
+                "The 'file' field is not supported in batch requests. "
+                "Use 'file_url' instead."
+            )
+        return data
+
+
 BatchRequestInputBody: TypeAlias = (
-    ChatCompletionRequest | EmbeddingRequest | ScoreRequest | RerankRequest
+    ChatCompletionRequest
+    | EmbeddingRequest
+    | ScoreRequest
+    | RerankRequest
+    | BatchTranscriptionRequest
+    | BatchTranslationRequest
 )
 
 
@@ -88,6 +173,10 @@ class BatchRequestInput(OpenAIBaseModel):
             return TypeAdapter(ScoreRequest).validate_python(value)
         if url.endswith("/rerank"):
             return RerankRequest.model_validate(value)
+        if url == "/v1/audio/transcriptions":
+            return BatchTranscriptionRequest.model_validate(value)
+        if url == "/v1/audio/translations":
+            return BatchTranslationRequest.model_validate(value)
         return TypeAdapter(BatchRequestInputBody).validate_python(value)
 
 
@@ -104,6 +193,10 @@ class BatchResponseData(OpenAIBaseModel):
         | EmbeddingResponse
         | ScoreResponse
         | RerankResponse
+        | TranscriptionResponse
+        | TranscriptionResponseVerbose
+        | TranslationResponse
+        | TranslationResponseVerbose
         | None
     ) = None
 
 
@@ -361,6 +454,49 @@ async def write_file(
         await write_local_file(path_or_url, batch_outputs)
 
 
+async def download_bytes_from_url(url: str) -> bytes:
+    """
+    Download data from a URL or decode from a data URL.
+
+    Args:
+        url: Either an HTTP/HTTPS URL or a data URL (data:...;base64,...)
+
+    Returns:
+        Data as bytes
+    """
+    parsed = urlparse(url)
+
+    # Handle data URLs (base64 encoded)
+    if parsed.scheme == "data":
+        # Format: data:...;base64,
+        if "," in url:
+            header, data = url.split(",", 1)
+            if "base64" in header:
+                return base64.b64decode(data)
+            else:
+                raise ValueError(f"Unsupported data URL encoding: {header}")
+        else:
+            raise ValueError(f"Invalid data URL format: {url}")
+
+    # Handle HTTP/HTTPS URLs
+    elif parsed.scheme in ("http", "https"):
+        async with (
+            aiohttp.ClientSession() as session,
+            session.get(url) as resp,
+        ):
+            if resp.status != 200:
+                raise Exception(
+                    f"Failed to download data from URL: {url}. Status: {resp.status}"
+                )
+            return await resp.read()
+
+    else:
+        raise ValueError(
+            f"Unsupported URL scheme: {parsed.scheme}. "
+            "Supported schemes: http, https, data"
+        )
+
+
 def make_error_request_output(
     request: BatchRequestInput, error_msg: str
 ) -> BatchRequestOutput:
@@ -391,7 +527,16 @@ async def run_request(
 
     if isinstance(
         response,
-        (ChatCompletionResponse, EmbeddingResponse, ScoreResponse, RerankResponse),
+        (
+            ChatCompletionResponse,
+            EmbeddingResponse,
+            ScoreResponse,
+            RerankResponse,
+            TranscriptionResponse,
+            TranscriptionResponseVerbose,
+            TranslationResponse,
+            TranslationResponseVerbose,
+        ),
     ):
         batch_output = BatchRequestOutput(
             id=f"vllm-{random_uuid()}",
@@ -420,38 +565,130 @@ async def run_request(
     return batch_output
 
 
-def validate_run_batch_args(args):
-    valid_reasoning_parsers = ReasoningParserManager.list_registered()
-    if (
-        reasoning_parser := args.structured_outputs_config.reasoning_parser
-    ) and reasoning_parser not in valid_reasoning_parsers:
-        raise KeyError(
-            f"invalid reasoning parser: {reasoning_parser} "
-            f"(chose from {{ {','.join(valid_reasoning_parsers)} }})"
-        )
+def handle_endpoint_request(
+    request: BatchRequestInput,
+    tracker: BatchProgressTracker,
+    url_matcher: Callable[[str], bool],
+    handler_getter: Callable[[], Callable | None],
+    wrapper_fn: Callable[[Callable], Callable] | None = None,
+) -> Awaitable[BatchRequestOutput] | None:
+    """
+    Generic handler for endpoint requests.
+
+    Args:
+        request: The batch request input
+        tracker: Progress tracker for the batch
+        url_matcher: Function that takes a URL and returns True if it matches
+        handler_getter: Function that returns the handler function or None
+        wrapper_fn: Optional function to wrap the handler (e.g., for transcriptions)
+
+    Returns:
+        Awaitable[BatchRequestOutput] if the request was handled,
+        None if the URL didn't match
+    """
+    if not url_matcher(request.url):
+        return None
+
+    handler_fn = handler_getter()
+    if handler_fn is None:
+        error_msg = f"Model does not support endpoint: {request.url}"
+        return make_async_error_request_output(request, error_msg=error_msg)
+
+    # Apply wrapper if provided (e.g., for transcriptions/translations)
+    if wrapper_fn is not None:
+        handler_fn = wrapper_fn(handler_fn)
+
+    tracker.submitted()
+    return run_request(handler_fn, request, tracker)
 
 
-async def run_batch(
+def make_transcription_wrapper(is_translation: bool):
+    """
+    Factory function to create a wrapper for transcription/translation handlers.
+    The wrapper converts BatchTranscriptionRequest or BatchTranslationRequest
+    to TranscriptionRequest or TranslationRequest and calls the appropriate handler.
+
+    Args:
+        is_translation: If True, process as translation; otherwise process
+            as transcription
+
+    Returns:
+        A function that takes a handler and returns a wrapped handler
+    """
+
+    def wrapper(handler_fn: Callable):
+        async def transcription_wrapper(
+            batch_request_body: (BatchTranscriptionRequest | BatchTranslationRequest),
+        ) -> (
+            TranscriptionResponse
+            | TranscriptionResponseVerbose
+            | TranslationResponse
+            | TranslationResponseVerbose
+            | ErrorResponse
+        ):
+            try:
+                # Download data from URL
+                audio_data = await download_bytes_from_url(batch_request_body.file_url)
+
+                # Create a mock file from the downloaded audio data
+                mock_file = UploadFile(
+                    file=BytesIO(audio_data),
+                    filename="audio.bin",
+                )
+
+                # Convert batch request to regular request
+                # by copying all fields except file_url and setting file to mock_file
+                request_dict = batch_request_body.model_dump(exclude={"file_url"})
+                request_dict["file"] = mock_file
+
+                if is_translation:
+                    # Create TranslationRequest from BatchTranslationRequest
+                    translation_request = TranslationRequest.model_validate(
+                        request_dict
+                    )
+                    return await handler_fn(audio_data, translation_request)
+                else:
+                    # Create TranscriptionRequest from BatchTranscriptionRequest
+                    transcription_request = TranscriptionRequest.model_validate(
+                        request_dict
+                    )
+                    return await handler_fn(audio_data, transcription_request)
+            except Exception as e:
+                operation = "translation" if is_translation else "transcription"
+                return ErrorResponse(
+                    error=ErrorInfo(
+                        message=f"Failed to process {operation}: {str(e)}",
+                        type="BadRequestError",
+                        code=HTTPStatus.BAD_REQUEST.value,
+                    )
+                )
+
+        return transcription_wrapper
+
+    return wrapper
+
+
+def build_endpoint_registry(
     engine_client: EngineClient,
     args: Namespace,
-) -> None:
-    if args.served_model_name is not None:
-        served_model_names = args.served_model_name
-    else:
-        served_model_names = [args.model]
+    base_model_paths: list[BaseModelPath],
+    request_logger: RequestLogger | None,
+    supported_tasks: tuple[SupportedTask, ...],
+) -> dict[str, dict[str, Any]]:
+    """
+    Build the endpoint registry with all serving objects and handler configurations.
 
-    if args.enable_log_requests:
-        request_logger = RequestLogger(max_log_len=args.max_log_len)
-    else:
-        request_logger = None
-
-    base_model_paths = [
-        BaseModelPath(name=name, model_path=args.model) for name in served_model_names
-    ]
+    Args:
+        engine_client: The engine client
+        args: Command line arguments
+        base_model_paths: List of base model paths
+        request_logger: Optional request logger
+        supported_tasks: Tuple of supported tasks
 
+    Returns:
+        Dictionary mapping endpoint keys to their configurations
+    """
     model_config = engine_client.model_config
 
-    supported_tasks = await engine_client.get_supported_tasks()
-    logger.info("Supported tasks: %s", supported_tasks)
-
     # Create the openai serving objects.
     openai_serving_models = OpenAIServingModels(
@@ -507,6 +744,129 @@ async def run_batch(
         else None
     )
 
+    openai_serving_transcription = (
+        OpenAIServingTranscription(
+            engine_client,
+            openai_serving_models,
+            request_logger=request_logger,
+            enable_force_include_usage=args.enable_force_include_usage,
+        )
+        if "transcription" in supported_tasks
+        else None
+    )
+
+    openai_serving_translation = (
+        OpenAIServingTranslation(
+            engine_client,
+            openai_serving_models,
+            request_logger=request_logger,
+            enable_force_include_usage=args.enable_force_include_usage,
+        )
+        if "transcription" in supported_tasks
+        else None
+    )
+
+    # Registry of endpoint configurations
+    endpoint_registry: dict[str, dict[str, Any]] = {
+        "completions": {
+            "url_matcher": lambda url: url == "/v1/chat/completions",
+            "handler_getter": lambda: (
+                openai_serving_chat.create_chat_completion
+                if openai_serving_chat is not None
+                else None
+            ),
+            "wrapper_fn": None,
+        },
+        "embeddings": {
+            "url_matcher": lambda url: url == "/v1/embeddings",
+            "handler_getter": lambda: (
+                openai_serving_embedding.create_embedding
+                if openai_serving_embedding is not None
+                else None
+            ),
+            "wrapper_fn": None,
+        },
+        "score": {
+            "url_matcher": lambda url: url.endswith("/score"),
+            "handler_getter": lambda: (
+                openai_serving_scores.create_score
+                if openai_serving_scores is not None
+                else None
+            ),
+            "wrapper_fn": None,
+        },
+        "rerank": {
+            "url_matcher": lambda url: url.endswith("/rerank"),
+            "handler_getter": lambda: (
+                openai_serving_scores.do_rerank
+                if openai_serving_scores is not None
+                else None
+            ),
+            "wrapper_fn": None,
+        },
+        "transcriptions": {
+            "url_matcher": lambda url: url == "/v1/audio/transcriptions",
+            "handler_getter": lambda: (
+                openai_serving_transcription.create_transcription
+                if openai_serving_transcription is not None
+                else None
+            ),
+            "wrapper_fn": make_transcription_wrapper(is_translation=False),
+        },
+        "translations": {
+            "url_matcher": lambda url: url == "/v1/audio/translations",
+            "handler_getter": lambda: (
+                openai_serving_translation.create_translation
+                if openai_serving_translation is not None
+                else None
+            ),
+            "wrapper_fn": make_transcription_wrapper(is_translation=True),
+        },
+    }
+
+    return endpoint_registry
+
+
+def validate_run_batch_args(args):
+    valid_reasoning_parsers = ReasoningParserManager.list_registered()
+    if (
+        reasoning_parser := args.structured_outputs_config.reasoning_parser
+    ) and reasoning_parser not in valid_reasoning_parsers:
+        raise KeyError(
+            f"invalid reasoning parser: {reasoning_parser} "
+            f"(choose from {{ {','.join(valid_reasoning_parsers)} }})"
+        )
+
+
+async def run_batch(
+    engine_client: EngineClient,
+    args: Namespace,
+) -> None:
+    if args.served_model_name is not None:
+        served_model_names = args.served_model_name
+    else:
+        served_model_names = [args.model]
+
+    if args.enable_log_requests:
+        request_logger = RequestLogger(max_log_len=args.max_log_len)
+    else:
+        request_logger = None
+
+    base_model_paths = [
+        BaseModelPath(name=name, model_path=args.model) for name in served_model_names
+    ]
+
+    supported_tasks = await engine_client.get_supported_tasks()
+    logger.info("Supported tasks: %s", supported_tasks)
+
+    endpoint_registry = build_endpoint_registry(
+        engine_client=engine_client,
+        args=args,
+        base_model_paths=base_model_paths,
+        request_logger=request_logger,
+        supported_tasks=supported_tasks,
+    )
+
     tracker = BatchProgressTracker()
     logger.info("Reading batch from %s...", args.input_file)
 
@@ -520,84 +880,32 @@ async def run_batch(
         request = BatchRequestInput.model_validate_json(request_json)
 
-        # Determine the type of request and run it.
-        if request.url == "/v1/chat/completions":
-            chat_handler_fn = (
-                openai_serving_chat.create_chat_completion
-                if openai_serving_chat is not None
-                else None
-            )
-            if chat_handler_fn is None:
-                response_futures.append(
-                    make_async_error_request_output(
-                        request,
-                        error_msg="The model does not support Chat Completions API",
-                    )
-                )
-                continue
+        # Use the last segment of the URL as the endpoint key.
+        # More advanced URL matching is done in url_matcher of endpoint_registry.
+        endpoint_key = request.url.split("/")[-1]
 
-            response_futures.append(run_request(chat_handler_fn, request, tracker))
-            tracker.submitted()
-        elif request.url == "/v1/embeddings":
-            embed_handler_fn = (
-                openai_serving_embedding.create_embedding
-                if openai_serving_embedding is not None
-                else None
+        result = None
+        if endpoint_key in endpoint_registry:
+            endpoint_config = endpoint_registry[endpoint_key]
+            result = handle_endpoint_request(
+                request,
+                tracker,
+                url_matcher=endpoint_config["url_matcher"],
+                handler_getter=endpoint_config["handler_getter"],
+                wrapper_fn=endpoint_config["wrapper_fn"],
             )
-            if embed_handler_fn is None:
-                response_futures.append(
-                    make_async_error_request_output(
-                        request,
-                        error_msg="The model does not support Embeddings API",
-                    )
-                )
-                continue
 
-            response_futures.append(run_request(embed_handler_fn, request, tracker))
-            tracker.submitted()
-        elif request.url.endswith("/score"):
-            score_handler_fn = (
-                openai_serving_scores.create_score
-                if openai_serving_scores is not None
-                else None
-            )
-            if score_handler_fn is None:
-                response_futures.append(
-                    make_async_error_request_output(
-                        request,
-                        error_msg="The model does not support Scores API",
-                    )
-                )
-                continue
-
-            response_futures.append(run_request(score_handler_fn, request, tracker))
-            tracker.submitted()
-        elif request.url.endswith("/rerank"):
-            rerank_handler_fn = (
-                openai_serving_scores.do_rerank
-                if openai_serving_scores is not None
-                else None
-            )
-            if rerank_handler_fn is None:
-                response_futures.append(
-                    make_async_error_request_output(
-                        request,
-                        error_msg="The model does not support Rerank API",
-                    )
-                )
-                continue
-
-            response_futures.append(run_request(rerank_handler_fn, request, tracker))
-            tracker.submitted()
+        if result is not None:
+            response_futures.append(result)
         else:
             response_futures.append(
                 make_async_error_request_output(
                     request,
                     error_msg=f"URL {request.url} was used. "
                     "Supported endpoints: /v1/chat/completions, /v1/embeddings,"
-                    " /score, /rerank ."
-                    "See vllm/entrypoints/openai/api_server.py for supported "
-                    "score/rerank versions.",
+                    " /v1/audio/transcriptions, /v1/audio/translations, /score, "
+                    "/rerank. See vllm/entrypoints/openai/api_server.py "
+                    "for supported score/rerank versions.",
                 )
             )
diff --git a/vllm/entrypoints/openai/speech_to_text/serving.py b/vllm/entrypoints/openai/speech_to_text/serving.py
index 9d18f5aa3..b5ce17d0e 100644
--- a/vllm/entrypoints/openai/speech_to_text/serving.py
+++ b/vllm/entrypoints/openai/speech_to_text/serving.py
@@ -54,7 +54,10 @@ class OpenAIServingTranscription(OpenAISpeechToText):
         )
 
     async def create_transcription(
-        self, audio_data: bytes, request: TranscriptionRequest, raw_request: Request
+        self,
+        audio_data: bytes,
+        request: TranscriptionRequest,
+        raw_request: Request | None = None,
     ) -> (
         TranscriptionResponse
         | TranscriptionResponseVerbose
@@ -124,7 +127,10 @@ class OpenAIServingTranslation(OpenAISpeechToText):
         )
 
     async def create_translation(
-        self, audio_data: bytes, request: TranslationRequest, raw_request: Request
+        self,
+        audio_data: bytes,
+        request: TranslationRequest,
+        raw_request: Request | None = None,
     ) -> (
         TranslationResponse
        | TranslationResponseVerbose
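
Usage note: with this change, a batch input file can target the audio endpoints by referencing audio via file_url. A minimal sketch of an input line and invocation, mirroring the tests above; the file names and the example audio URL are placeholders, not values from this PR:

    # input.jsonl (one request per line; "file_url" may be an http(s) URL or a base64 data URL)
    {"custom_id": "request-1", "method": "POST", "url": "/v1/audio/transcriptions", "body": {"model": "openai/whisper-large-v3", "file_url": "https://example.com/audio.wav", "response_format": "json"}}

    # Run the batch with the same model named in the request bodies.
    vllm run-batch -i input.jsonl -o results.jsonl --model openai/whisper-large-v3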