[Frontend]Add support for transcriptions and translations to run_batch (#33934)
Signed-off-by: Pooya Davoodi <pooya.davoodi@parasail.io> Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
@@ -7,6 +7,7 @@ import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.assets.audio import AudioAsset
|
||||
from vllm.entrypoints.openai.run_batch import BatchRequestOutput
|
||||
|
||||
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
|
||||
@@ -42,6 +43,27 @@ INPUT_RERANK_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/re
|
||||
INPUT_REASONING_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "Qwen/Qwen3-0.6B", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Solve this math problem: 2+2=?"}]}}
|
||||
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "Qwen/Qwen3-0.6B", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "What is the capital of France?"}]}}"""
|
||||
|
||||
# This is a valid but minimal audio file for testing
|
||||
MINIMAL_WAV_BASE64 = "UklGRiQAAABXQVZFZm10IBAAAAABAAEAQB8AAEAfAAABAAgAZGF0YQAAAAA="
|
||||
INPUT_TRANSCRIPTION_BATCH = (
|
||||
'{{"custom_id": "request-1", "method": "POST", "url": "/v1/audio/transcriptions", '
|
||||
'"body": {{"model": "openai/whisper-large-v3", "file_url": "data:audio/wav;base64,{}", '
|
||||
'"response_format": "json"}}}}\n'
|
||||
).format(MINIMAL_WAV_BASE64)
|
||||
|
||||
INPUT_TRANSCRIPTION_HTTP_BATCH = (
|
||||
'{{"custom_id": "request-1", "method": "POST", "url": "/v1/audio/transcriptions", '
|
||||
'"body": {{"model": "openai/whisper-large-v3", "file_url": "{}", '
|
||||
'"response_format": "json"}}}}\n'
|
||||
).format(AudioAsset("mary_had_lamb").url)
|
||||
|
||||
INPUT_TRANSLATION_BATCH = (
|
||||
'{{"custom_id": "request-1", "method": "POST", "url": "/v1/audio/translations", '
|
||||
'"body": {{"model": "openai/whisper-small", "file_url": "{}", '
|
||||
'"response_format": "text", "language": "it", "to_language": "en", '
|
||||
'"temperature": 0.0}}}}\n'
|
||||
).format(AudioAsset("mary_had_lamb").url)
|
||||
|
||||
|
||||
def test_empty_file():
|
||||
with (
|
||||
@@ -238,3 +260,121 @@ def test_reasoning_parser():
|
||||
]
|
||||
assert reasoning is not None
|
||||
assert len(reasoning) > 0
|
||||
|
||||
|
||||
def test_transcription():
|
||||
with (
|
||||
tempfile.NamedTemporaryFile("w") as input_file,
|
||||
tempfile.NamedTemporaryFile("r") as output_file,
|
||||
):
|
||||
input_file.write(INPUT_TRANSCRIPTION_BATCH)
|
||||
input_file.flush()
|
||||
proc = subprocess.Popen(
|
||||
[
|
||||
"vllm",
|
||||
"run-batch",
|
||||
"-i",
|
||||
input_file.name,
|
||||
"-o",
|
||||
output_file.name,
|
||||
"--model",
|
||||
"openai/whisper-large-v3",
|
||||
],
|
||||
)
|
||||
proc.communicate()
|
||||
proc.wait()
|
||||
assert proc.returncode == 0, f"{proc=}"
|
||||
|
||||
contents = output_file.read()
|
||||
print(f"\n\ncontents: {contents}\n\n")
|
||||
for line in contents.strip().split("\n"):
|
||||
BatchRequestOutput.model_validate_json(line)
|
||||
|
||||
line_dict = json.loads(line)
|
||||
assert isinstance(line_dict, dict)
|
||||
assert line_dict["error"] is None
|
||||
|
||||
response_body = line_dict["response"]["body"]
|
||||
assert response_body is not None
|
||||
assert "text" in response_body
|
||||
assert "usage" in response_body
|
||||
|
||||
|
||||
def test_transcription_http_url():
|
||||
with (
|
||||
tempfile.NamedTemporaryFile("w") as input_file,
|
||||
tempfile.NamedTemporaryFile("r") as output_file,
|
||||
):
|
||||
input_file.write(INPUT_TRANSCRIPTION_HTTP_BATCH)
|
||||
input_file.flush()
|
||||
proc = subprocess.Popen(
|
||||
[
|
||||
"vllm",
|
||||
"run-batch",
|
||||
"-i",
|
||||
input_file.name,
|
||||
"-o",
|
||||
output_file.name,
|
||||
"--model",
|
||||
"openai/whisper-large-v3",
|
||||
],
|
||||
)
|
||||
proc.communicate()
|
||||
proc.wait()
|
||||
assert proc.returncode == 0, f"{proc=}"
|
||||
|
||||
contents = output_file.read()
|
||||
for line in contents.strip().split("\n"):
|
||||
BatchRequestOutput.model_validate_json(line)
|
||||
|
||||
line_dict = json.loads(line)
|
||||
assert isinstance(line_dict, dict)
|
||||
assert line_dict["error"] is None
|
||||
|
||||
response_body = line_dict["response"]["body"]
|
||||
assert response_body is not None
|
||||
assert "text" in response_body
|
||||
assert "usage" in response_body
|
||||
|
||||
transcription_text = response_body["text"]
|
||||
assert "Mary had a little lamb" in transcription_text
|
||||
|
||||
|
||||
def test_translation():
|
||||
with (
|
||||
tempfile.NamedTemporaryFile("w") as input_file,
|
||||
tempfile.NamedTemporaryFile("r") as output_file,
|
||||
):
|
||||
input_file.write(INPUT_TRANSLATION_BATCH)
|
||||
input_file.flush()
|
||||
proc = subprocess.Popen(
|
||||
[
|
||||
"vllm",
|
||||
"run-batch",
|
||||
"-i",
|
||||
input_file.name,
|
||||
"-o",
|
||||
output_file.name,
|
||||
"--model",
|
||||
"openai/whisper-small",
|
||||
],
|
||||
)
|
||||
proc.communicate()
|
||||
proc.wait()
|
||||
assert proc.returncode == 0, f"{proc=}"
|
||||
|
||||
contents = output_file.read()
|
||||
for line in contents.strip().split("\n"):
|
||||
BatchRequestOutput.model_validate_json(line)
|
||||
|
||||
line_dict = json.loads(line)
|
||||
assert isinstance(line_dict, dict)
|
||||
assert line_dict["error"] is None
|
||||
|
||||
response_body = line_dict["response"]["body"]
|
||||
assert response_body is not None
|
||||
assert "text" in response_body
|
||||
|
||||
translation_text = response_body["text"]
|
||||
translation_text_lower = str(translation_text).strip().lower()
|
||||
assert "mary" in translation_text_lower or "lamb" in translation_text_lower
|
||||
|
||||
@@ -2,17 +2,20 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import tempfile
|
||||
from argparse import Namespace
|
||||
from collections.abc import Awaitable, Callable
|
||||
from http import HTTPStatus
|
||||
from io import StringIO
|
||||
from io import BytesIO, StringIO
|
||||
from typing import Any, TypeAlias
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
from fastapi import UploadFile
|
||||
from prometheus_client import start_http_server
|
||||
from pydantic import TypeAdapter, field_validator
|
||||
from pydantic import Field, TypeAdapter, field_validator, model_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -25,12 +28,28 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
)
|
||||
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
OpenAIBaseModel,
|
||||
)
|
||||
from vllm.entrypoints.openai.models.protocol import BaseModelPath
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest, EmbeddingResponse
|
||||
from vllm.entrypoints.openai.speech_to_text.protocol import (
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponse,
|
||||
TranscriptionResponseVerbose,
|
||||
TranslationRequest,
|
||||
TranslationResponse,
|
||||
TranslationResponseVerbose,
|
||||
)
|
||||
from vllm.entrypoints.openai.speech_to_text.serving import (
|
||||
OpenAIServingTranscription,
|
||||
OpenAIServingTranslation,
|
||||
)
|
||||
from vllm.entrypoints.pooling.embed.protocol import (
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
)
|
||||
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.pooling.score.protocol import (
|
||||
RerankRequest,
|
||||
@@ -41,6 +60,7 @@ from vllm.entrypoints.pooling.score.protocol import (
|
||||
from vllm.entrypoints.pooling.score.serving import ServingScores
|
||||
from vllm.logger import init_logger
|
||||
from vllm.reasoning import ReasoningParserManager
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
@@ -48,8 +68,73 @@ from vllm.version import __version__ as VLLM_VERSION
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BatchTranscriptionRequest(TranscriptionRequest):
|
||||
"""
|
||||
Batch transcription request that uses file_url instead of file.
|
||||
|
||||
This class extends TranscriptionRequest but replaces the file field
|
||||
with file_url to support batch processing from audio files written in JSON format.
|
||||
"""
|
||||
|
||||
file_url: str = Field(
|
||||
...,
|
||||
description=(
|
||||
"Either a URL of the audio or a data URL with base64 encoded audio data. "
|
||||
),
|
||||
)
|
||||
|
||||
# Override file to be optional and unused for batch processing
|
||||
file: UploadFile | None = Field(default=None, exclude=True) # type: ignore[assignment]
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_no_file(cls, data: Any):
|
||||
"""Ensure file field is not provided in batch requests."""
|
||||
if isinstance(data, dict) and "file" in data:
|
||||
raise ValueError(
|
||||
"The 'file' field is not supported in batch requests. "
|
||||
"Use 'file_url' instead."
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
class BatchTranslationRequest(TranslationRequest):
|
||||
"""
|
||||
Batch translation request that uses file_url instead of file.
|
||||
|
||||
This class extends TranslationRequest but replaces the file field
|
||||
with file_url to support batch processing from audio files written in JSON format.
|
||||
"""
|
||||
|
||||
file_url: str = Field(
|
||||
...,
|
||||
description=(
|
||||
"Either a URL of the audio or a data URL with base64 encoded audio data. "
|
||||
),
|
||||
)
|
||||
|
||||
# Override file to be optional and unused for batch processing
|
||||
file: UploadFile | None = Field(default=None, exclude=True) # type: ignore[assignment]
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_no_file(cls, data: Any):
|
||||
"""Ensure file field is not provided in batch requests."""
|
||||
if isinstance(data, dict) and "file" in data:
|
||||
raise ValueError(
|
||||
"The 'file' field is not supported in batch requests. "
|
||||
"Use 'file_url' instead."
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
BatchRequestInputBody: TypeAlias = (
|
||||
ChatCompletionRequest | EmbeddingRequest | ScoreRequest | RerankRequest
|
||||
ChatCompletionRequest
|
||||
| EmbeddingRequest
|
||||
| ScoreRequest
|
||||
| RerankRequest
|
||||
| BatchTranscriptionRequest
|
||||
| BatchTranslationRequest
|
||||
)
|
||||
|
||||
|
||||
@@ -88,6 +173,10 @@ class BatchRequestInput(OpenAIBaseModel):
|
||||
return TypeAdapter(ScoreRequest).validate_python(value)
|
||||
if url.endswith("/rerank"):
|
||||
return RerankRequest.model_validate(value)
|
||||
if url == "/v1/audio/transcriptions":
|
||||
return BatchTranscriptionRequest.model_validate(value)
|
||||
if url == "/v1/audio/translations":
|
||||
return BatchTranslationRequest.model_validate(value)
|
||||
return TypeAdapter(BatchRequestInputBody).validate_python(value)
|
||||
|
||||
|
||||
@@ -104,6 +193,10 @@ class BatchResponseData(OpenAIBaseModel):
|
||||
| EmbeddingResponse
|
||||
| ScoreResponse
|
||||
| RerankResponse
|
||||
| TranscriptionResponse
|
||||
| TranscriptionResponseVerbose
|
||||
| TranslationResponse
|
||||
| TranslationResponseVerbose
|
||||
| None
|
||||
) = None
|
||||
|
||||
@@ -361,6 +454,49 @@ async def write_file(
|
||||
await write_local_file(path_or_url, batch_outputs)
|
||||
|
||||
|
||||
async def download_bytes_from_url(url: str) -> bytes:
|
||||
"""
|
||||
Download data from a URL or decode from a data URL.
|
||||
|
||||
Args:
|
||||
url: Either an HTTP/HTTPS URL or a data URL (data:...;base64,...)
|
||||
|
||||
Returns:
|
||||
Data as bytes
|
||||
"""
|
||||
parsed = urlparse(url)
|
||||
|
||||
# Handle data URLs (base64 encoded)
|
||||
if parsed.scheme == "data":
|
||||
# Format: data:...;base64,<base64_data>
|
||||
if "," in url:
|
||||
header, data = url.split(",", 1)
|
||||
if "base64" in header:
|
||||
return base64.b64decode(data)
|
||||
else:
|
||||
raise ValueError(f"Unsupported data URL encoding: {header}")
|
||||
else:
|
||||
raise ValueError(f"Invalid data URL format: {url}")
|
||||
|
||||
# Handle HTTP/HTTPS URLs
|
||||
elif parsed.scheme in ("http", "https"):
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.get(url) as resp,
|
||||
):
|
||||
if resp.status != 200:
|
||||
raise Exception(
|
||||
f"Failed to download data from URL: {url}. Status: {resp.status}"
|
||||
)
|
||||
return await resp.read()
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported URL scheme: {parsed.scheme}. "
|
||||
"Supported schemes: http, https, data"
|
||||
)
|
||||
|
||||
|
||||
def make_error_request_output(
|
||||
request: BatchRequestInput, error_msg: str
|
||||
) -> BatchRequestOutput:
|
||||
@@ -391,7 +527,16 @@ async def run_request(
|
||||
|
||||
if isinstance(
|
||||
response,
|
||||
(ChatCompletionResponse, EmbeddingResponse, ScoreResponse, RerankResponse),
|
||||
(
|
||||
ChatCompletionResponse,
|
||||
EmbeddingResponse,
|
||||
ScoreResponse,
|
||||
RerankResponse,
|
||||
TranscriptionResponse,
|
||||
TranscriptionResponseVerbose,
|
||||
TranslationResponse,
|
||||
TranslationResponseVerbose,
|
||||
),
|
||||
):
|
||||
batch_output = BatchRequestOutput(
|
||||
id=f"vllm-{random_uuid()}",
|
||||
@@ -420,38 +565,130 @@ async def run_request(
|
||||
return batch_output
|
||||
|
||||
|
||||
def validate_run_batch_args(args):
|
||||
valid_reasoning_parsers = ReasoningParserManager.list_registered()
|
||||
if (
|
||||
reasoning_parser := args.structured_outputs_config.reasoning_parser
|
||||
) and reasoning_parser not in valid_reasoning_parsers:
|
||||
raise KeyError(
|
||||
f"invalid reasoning parser: {reasoning_parser} "
|
||||
f"(chose from {{ {','.join(valid_reasoning_parsers)} }})"
|
||||
)
|
||||
def handle_endpoint_request(
|
||||
request: BatchRequestInput,
|
||||
tracker: BatchProgressTracker,
|
||||
url_matcher: Callable[[str], bool],
|
||||
handler_getter: Callable[[], Callable | None],
|
||||
wrapper_fn: Callable[[Callable], Callable] | None = None,
|
||||
) -> Awaitable[BatchRequestOutput] | None:
|
||||
"""
|
||||
Generic handler for endpoint requests.
|
||||
|
||||
Args:
|
||||
request: The batch request input
|
||||
tracker: Progress tracker for the batch
|
||||
url_matcher: Function that takes a URL and returns True if it matches
|
||||
handler_getter: Function that returns the handler function or None
|
||||
wrapper_fn: Optional function to wrap the handler (e.g., for transcriptions)
|
||||
|
||||
Returns:
|
||||
Awaitable[BatchRequestOutput] if the request was handled,
|
||||
None if URL didn't match
|
||||
"""
|
||||
if not url_matcher(request.url):
|
||||
return None
|
||||
|
||||
handler_fn = handler_getter()
|
||||
if handler_fn is None:
|
||||
error_msg = f"Model does not support endpoint: {request.url}"
|
||||
return make_async_error_request_output(request, error_msg=error_msg)
|
||||
|
||||
# Apply wrapper if provided (e.g., for transcriptions/translations)
|
||||
if wrapper_fn is not None:
|
||||
handler_fn = wrapper_fn(handler_fn)
|
||||
|
||||
tracker.submitted()
|
||||
return run_request(handler_fn, request, tracker)
|
||||
|
||||
|
||||
async def run_batch(
|
||||
def make_transcription_wrapper(is_translation: bool):
|
||||
"""
|
||||
Factory function to create a wrapper for transcription/translation handlers.
|
||||
The wrapper converts BatchTranscriptionRequest or BatchTranslationRequest
|
||||
to TranscriptionRequest or TranslationRequest and calls the appropriate handler.
|
||||
|
||||
Args:
|
||||
is_translation: If True, process as translation; otherwise process
|
||||
as transcription
|
||||
|
||||
Returns:
|
||||
A function that takes a handler and returns a wrapped handler
|
||||
"""
|
||||
|
||||
def wrapper(handler_fn: Callable):
|
||||
async def transcription_wrapper(
|
||||
batch_request_body: (BatchTranscriptionRequest | BatchTranslationRequest),
|
||||
) -> (
|
||||
TranscriptionResponse
|
||||
| TranscriptionResponseVerbose
|
||||
| TranslationResponse
|
||||
| TranslationResponseVerbose
|
||||
| ErrorResponse
|
||||
):
|
||||
try:
|
||||
# Download data from URL
|
||||
audio_data = await download_bytes_from_url(batch_request_body.file_url)
|
||||
|
||||
# Create a mock file from the downloaded audio data
|
||||
mock_file = UploadFile(
|
||||
file=BytesIO(audio_data),
|
||||
filename="audio.bin",
|
||||
)
|
||||
|
||||
# Convert batch request to regular request
|
||||
# by copying all fields except file_url and setting file to mock_file
|
||||
request_dict = batch_request_body.model_dump(exclude={"file_url"})
|
||||
request_dict["file"] = mock_file
|
||||
|
||||
if is_translation:
|
||||
# Create TranslationRequest from BatchTranslationRequest
|
||||
translation_request = TranslationRequest.model_validate(
|
||||
request_dict
|
||||
)
|
||||
return await handler_fn(audio_data, translation_request)
|
||||
else:
|
||||
# Create TranscriptionRequest from BatchTranscriptionRequest
|
||||
transcription_request = TranscriptionRequest.model_validate(
|
||||
request_dict
|
||||
)
|
||||
return await handler_fn(audio_data, transcription_request)
|
||||
except Exception as e:
|
||||
operation = "translation" if is_translation else "transcription"
|
||||
return ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message=f"Failed to process {operation}: {str(e)}",
|
||||
type="BadRequestError",
|
||||
code=HTTPStatus.BAD_REQUEST.value,
|
||||
)
|
||||
)
|
||||
|
||||
return transcription_wrapper
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def build_endpoint_registry(
|
||||
engine_client: EngineClient,
|
||||
args: Namespace,
|
||||
) -> None:
|
||||
if args.served_model_name is not None:
|
||||
served_model_names = args.served_model_name
|
||||
else:
|
||||
served_model_names = [args.model]
|
||||
base_model_paths: list[BaseModelPath],
|
||||
request_logger: RequestLogger | None,
|
||||
supported_tasks: tuple[SupportedTask, ...],
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""
|
||||
Build the endpoint registry with all serving objects and handler configurations.
|
||||
|
||||
if args.enable_log_requests:
|
||||
request_logger = RequestLogger(max_log_len=args.max_log_len)
|
||||
else:
|
||||
request_logger = None
|
||||
|
||||
base_model_paths = [
|
||||
BaseModelPath(name=name, model_path=args.model) for name in served_model_names
|
||||
]
|
||||
Args:
|
||||
engine_client: The engine client
|
||||
args: Command line arguments
|
||||
base_model_paths: List of base model paths
|
||||
request_logger: Optional request logger
|
||||
supported_tasks: Tuple of supported tasks
|
||||
|
||||
Returns:
|
||||
Dictionary mapping endpoint keys to their configurations
|
||||
"""
|
||||
model_config = engine_client.model_config
|
||||
supported_tasks = await engine_client.get_supported_tasks()
|
||||
logger.info("Supported tasks: %s", supported_tasks)
|
||||
|
||||
# Create the openai serving objects.
|
||||
openai_serving_models = OpenAIServingModels(
|
||||
@@ -507,6 +744,129 @@ async def run_batch(
|
||||
else None
|
||||
)
|
||||
|
||||
openai_serving_transcription = (
|
||||
OpenAIServingTranscription(
|
||||
engine_client,
|
||||
openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
enable_force_include_usage=args.enable_force_include_usage,
|
||||
)
|
||||
if "transcription" in supported_tasks
|
||||
else None
|
||||
)
|
||||
|
||||
openai_serving_translation = (
|
||||
OpenAIServingTranslation(
|
||||
engine_client,
|
||||
openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
enable_force_include_usage=args.enable_force_include_usage,
|
||||
)
|
||||
if "transcription" in supported_tasks
|
||||
else None
|
||||
)
|
||||
|
||||
# Registry of endpoint configurations
|
||||
endpoint_registry: dict[str, dict[str, Any]] = {
|
||||
"completions": {
|
||||
"url_matcher": lambda url: url == "/v1/chat/completions",
|
||||
"handler_getter": lambda: (
|
||||
openai_serving_chat.create_chat_completion
|
||||
if openai_serving_chat is not None
|
||||
else None
|
||||
),
|
||||
"wrapper_fn": None,
|
||||
},
|
||||
"embeddings": {
|
||||
"url_matcher": lambda url: url == "/v1/embeddings",
|
||||
"handler_getter": lambda: (
|
||||
openai_serving_embedding.create_embedding
|
||||
if openai_serving_embedding is not None
|
||||
else None
|
||||
),
|
||||
"wrapper_fn": None,
|
||||
},
|
||||
"score": {
|
||||
"url_matcher": lambda url: url.endswith("/score"),
|
||||
"handler_getter": lambda: (
|
||||
openai_serving_scores.create_score
|
||||
if openai_serving_scores is not None
|
||||
else None
|
||||
),
|
||||
"wrapper_fn": None,
|
||||
},
|
||||
"rerank": {
|
||||
"url_matcher": lambda url: url.endswith("/rerank"),
|
||||
"handler_getter": lambda: (
|
||||
openai_serving_scores.do_rerank
|
||||
if openai_serving_scores is not None
|
||||
else None
|
||||
),
|
||||
"wrapper_fn": None,
|
||||
},
|
||||
"transcriptions": {
|
||||
"url_matcher": lambda url: url == "/v1/audio/transcriptions",
|
||||
"handler_getter": lambda: (
|
||||
openai_serving_transcription.create_transcription
|
||||
if openai_serving_transcription is not None
|
||||
else None
|
||||
),
|
||||
"wrapper_fn": make_transcription_wrapper(is_translation=False),
|
||||
},
|
||||
"translations": {
|
||||
"url_matcher": lambda url: url == "/v1/audio/translations",
|
||||
"handler_getter": lambda: (
|
||||
openai_serving_translation.create_translation
|
||||
if openai_serving_translation is not None
|
||||
else None
|
||||
),
|
||||
"wrapper_fn": make_transcription_wrapper(is_translation=True),
|
||||
},
|
||||
}
|
||||
|
||||
return endpoint_registry
|
||||
|
||||
|
||||
def validate_run_batch_args(args):
|
||||
valid_reasoning_parsers = ReasoningParserManager.list_registered()
|
||||
if (
|
||||
reasoning_parser := args.structured_outputs_config.reasoning_parser
|
||||
) and reasoning_parser not in valid_reasoning_parsers:
|
||||
raise KeyError(
|
||||
f"invalid reasoning parser: {reasoning_parser} "
|
||||
f"(chose from {{ {','.join(valid_reasoning_parsers)} }})"
|
||||
)
|
||||
|
||||
|
||||
async def run_batch(
|
||||
engine_client: EngineClient,
|
||||
args: Namespace,
|
||||
) -> None:
|
||||
if args.served_model_name is not None:
|
||||
served_model_names = args.served_model_name
|
||||
else:
|
||||
served_model_names = [args.model]
|
||||
|
||||
if args.enable_log_requests:
|
||||
request_logger = RequestLogger(max_log_len=args.max_log_len)
|
||||
else:
|
||||
request_logger = None
|
||||
|
||||
base_model_paths = [
|
||||
BaseModelPath(name=name, model_path=args.model) for name in served_model_names
|
||||
]
|
||||
|
||||
supported_tasks = await engine_client.get_supported_tasks()
|
||||
logger.info("Supported tasks: %s", supported_tasks)
|
||||
|
||||
endpoint_registry = build_endpoint_registry(
|
||||
engine_client=engine_client,
|
||||
args=args,
|
||||
base_model_paths=base_model_paths,
|
||||
request_logger=request_logger,
|
||||
supported_tasks=supported_tasks,
|
||||
)
|
||||
|
||||
tracker = BatchProgressTracker()
|
||||
logger.info("Reading batch from %s...", args.input_file)
|
||||
|
||||
@@ -520,84 +880,32 @@ async def run_batch(
|
||||
|
||||
request = BatchRequestInput.model_validate_json(request_json)
|
||||
|
||||
# Determine the type of request and run it.
|
||||
if request.url == "/v1/chat/completions":
|
||||
chat_handler_fn = (
|
||||
openai_serving_chat.create_chat_completion
|
||||
if openai_serving_chat is not None
|
||||
else None
|
||||
)
|
||||
if chat_handler_fn is None:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg="The model does not support Chat Completions API",
|
||||
)
|
||||
)
|
||||
continue
|
||||
# Use the last segment of the URL as the endpoint key.
|
||||
# More advanced URL matching is done in url_matcher of endpoint_registry.
|
||||
endpoint_key = request.url.split("/")[-1]
|
||||
|
||||
response_futures.append(run_request(chat_handler_fn, request, tracker))
|
||||
tracker.submitted()
|
||||
elif request.url == "/v1/embeddings":
|
||||
embed_handler_fn = (
|
||||
openai_serving_embedding.create_embedding
|
||||
if openai_serving_embedding is not None
|
||||
else None
|
||||
result = None
|
||||
if endpoint_key in endpoint_registry:
|
||||
endpoint_config = endpoint_registry[endpoint_key]
|
||||
result = handle_endpoint_request(
|
||||
request,
|
||||
tracker,
|
||||
url_matcher=endpoint_config["url_matcher"],
|
||||
handler_getter=endpoint_config["handler_getter"],
|
||||
wrapper_fn=endpoint_config["wrapper_fn"],
|
||||
)
|
||||
if embed_handler_fn is None:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg="The model does not support Embeddings API",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
response_futures.append(run_request(embed_handler_fn, request, tracker))
|
||||
tracker.submitted()
|
||||
elif request.url.endswith("/score"):
|
||||
score_handler_fn = (
|
||||
openai_serving_scores.create_score
|
||||
if openai_serving_scores is not None
|
||||
else None
|
||||
)
|
||||
if score_handler_fn is None:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg="The model does not support Scores API",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
response_futures.append(run_request(score_handler_fn, request, tracker))
|
||||
tracker.submitted()
|
||||
elif request.url.endswith("/rerank"):
|
||||
rerank_handler_fn = (
|
||||
openai_serving_scores.do_rerank
|
||||
if openai_serving_scores is not None
|
||||
else None
|
||||
)
|
||||
if rerank_handler_fn is None:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg="The model does not support Rerank API",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
response_futures.append(run_request(rerank_handler_fn, request, tracker))
|
||||
tracker.submitted()
|
||||
if result is not None:
|
||||
response_futures.append(result)
|
||||
else:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg=f"URL {request.url} was used. "
|
||||
"Supported endpoints: /v1/chat/completions, /v1/embeddings,"
|
||||
" /score, /rerank ."
|
||||
"See vllm/entrypoints/openai/api_server.py for supported "
|
||||
"score/rerank versions.",
|
||||
" /v1/audio/transcriptions, /v1/audio/translations, /score, "
|
||||
" /rerank. See vllm/entrypoints/openai/api_server.py "
|
||||
"for supported score/rerank versions.",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -54,7 +54,10 @@ class OpenAIServingTranscription(OpenAISpeechToText):
|
||||
)
|
||||
|
||||
async def create_transcription(
|
||||
self, audio_data: bytes, request: TranscriptionRequest, raw_request: Request
|
||||
self,
|
||||
audio_data: bytes,
|
||||
request: TranscriptionRequest,
|
||||
raw_request: Request | None = None,
|
||||
) -> (
|
||||
TranscriptionResponse
|
||||
| TranscriptionResponseVerbose
|
||||
@@ -124,7 +127,10 @@ class OpenAIServingTranslation(OpenAISpeechToText):
|
||||
)
|
||||
|
||||
async def create_translation(
|
||||
self, audio_data: bytes, request: TranslationRequest, raw_request: Request
|
||||
self,
|
||||
audio_data: bytes,
|
||||
request: TranslationRequest,
|
||||
raw_request: Request | None = None,
|
||||
) -> (
|
||||
TranslationResponse
|
||||
| TranslationResponseVerbose
|
||||
|
||||
Reference in New Issue
Block a user