[Refactor] [9/N] to simplify the vLLM openai translations serving architecture (#32313)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
@@ -17,12 +17,12 @@ from argparse import Namespace
|
|||||||
from collections.abc import AsyncIterator, Awaitable
|
from collections.abc import AsyncIterator, Awaitable
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Annotated, Any
|
from typing import Any
|
||||||
|
|
||||||
import model_hosting_container_standards.sagemaker as sagemaker_standards
|
import model_hosting_container_standards.sagemaker as sagemaker_standards
|
||||||
import pydantic
|
import pydantic
|
||||||
import uvloop
|
import uvloop
|
||||||
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
|
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
@@ -49,10 +49,6 @@ from vllm.entrypoints.openai.engine.protocol import (
|
|||||||
CompletionResponse,
|
CompletionResponse,
|
||||||
ErrorInfo,
|
ErrorInfo,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
TranscriptionRequest,
|
|
||||||
TranscriptionResponseVariant,
|
|
||||||
TranslationRequest,
|
|
||||||
TranslationResponseVariant,
|
|
||||||
)
|
)
|
||||||
from vllm.entrypoints.openai.engine.serving import OpenAIServing
|
from vllm.entrypoints.openai.engine.serving import OpenAIServing
|
||||||
from vllm.entrypoints.openai.orca_metrics import metrics_header
|
from vllm.entrypoints.openai.orca_metrics import metrics_header
|
||||||
@@ -62,7 +58,7 @@ from vllm.entrypoints.openai.serving_models import (
|
|||||||
BaseModelPath,
|
BaseModelPath,
|
||||||
OpenAIServingModels,
|
OpenAIServingModels,
|
||||||
)
|
)
|
||||||
from vllm.entrypoints.openai.serving_transcription import (
|
from vllm.entrypoints.openai.translations.serving import (
|
||||||
OpenAIServingTranscription,
|
OpenAIServingTranscription,
|
||||||
OpenAIServingTranslation,
|
OpenAIServingTranslation,
|
||||||
)
|
)
|
||||||
@@ -239,10 +235,6 @@ def models(request: Request) -> OpenAIServingModels:
|
|||||||
return request.app.state.openai_serving_models
|
return request.app.state.openai_serving_models
|
||||||
|
|
||||||
|
|
||||||
def responses(request: Request) -> OpenAIServingResponses | None:
|
|
||||||
return request.app.state.openai_serving_responses
|
|
||||||
|
|
||||||
|
|
||||||
def messages(request: Request) -> AnthropicServingMessages:
|
def messages(request: Request) -> AnthropicServingMessages:
|
||||||
return request.app.state.anthropic_serving_messages
|
return request.app.state.anthropic_serving_messages
|
||||||
|
|
||||||
@@ -259,22 +251,10 @@ def tokenization(request: Request) -> OpenAIServingTokenization:
|
|||||||
return request.app.state.openai_serving_tokenization
|
return request.app.state.openai_serving_tokenization
|
||||||
|
|
||||||
|
|
||||||
def transcription(request: Request) -> OpenAIServingTranscription:
|
|
||||||
return request.app.state.openai_serving_transcription
|
|
||||||
|
|
||||||
|
|
||||||
def translation(request: Request) -> OpenAIServingTranslation:
|
|
||||||
return request.app.state.openai_serving_translation
|
|
||||||
|
|
||||||
|
|
||||||
def engine_client(request: Request) -> EngineClient:
|
def engine_client(request: Request) -> EngineClient:
|
||||||
return request.app.state.engine_client
|
return request.app.state.engine_client
|
||||||
|
|
||||||
|
|
||||||
def generate_tokens(request: Request) -> ServingTokens | None:
|
|
||||||
return request.app.state.serving_tokens
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/load")
|
@router.get("/load")
|
||||||
async def get_server_load_metrics(request: Request):
|
async def get_server_load_metrics(request: Request):
|
||||||
# This endpoint returns the current server load metrics.
|
# This endpoint returns the current server load metrics.
|
||||||
@@ -410,84 +390,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
|||||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/v1/audio/transcriptions",
|
|
||||||
responses={
|
|
||||||
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
|
|
||||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
|
||||||
HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse},
|
|
||||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
@with_cancellation
|
|
||||||
@load_aware_call
|
|
||||||
async def create_transcriptions(
|
|
||||||
raw_request: Request, request: Annotated[TranscriptionRequest, Form()]
|
|
||||||
):
|
|
||||||
handler = transcription(raw_request)
|
|
||||||
if handler is None:
|
|
||||||
return base(raw_request).create_error_response(
|
|
||||||
message="The model does not support Transcriptions API"
|
|
||||||
)
|
|
||||||
|
|
||||||
audio_data = await request.file.read()
|
|
||||||
try:
|
|
||||||
generator = await handler.create_transcription(audio_data, request, raw_request)
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
|
|
||||||
) from e
|
|
||||||
|
|
||||||
if isinstance(generator, ErrorResponse):
|
|
||||||
return JSONResponse(
|
|
||||||
content=generator.model_dump(), status_code=generator.error.code
|
|
||||||
)
|
|
||||||
|
|
||||||
elif isinstance(generator, TranscriptionResponseVariant):
|
|
||||||
return JSONResponse(content=generator.model_dump())
|
|
||||||
|
|
||||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/v1/audio/translations",
|
|
||||||
responses={
|
|
||||||
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
|
|
||||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
|
||||||
HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse},
|
|
||||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
@with_cancellation
|
|
||||||
@load_aware_call
|
|
||||||
async def create_translations(
|
|
||||||
request: Annotated[TranslationRequest, Form()], raw_request: Request
|
|
||||||
):
|
|
||||||
handler = translation(raw_request)
|
|
||||||
if handler is None:
|
|
||||||
return base(raw_request).create_error_response(
|
|
||||||
message="The model does not support Translations API"
|
|
||||||
)
|
|
||||||
|
|
||||||
audio_data = await request.file.read()
|
|
||||||
try:
|
|
||||||
generator = await handler.create_translation(audio_data, request, raw_request)
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
|
|
||||||
) from e
|
|
||||||
|
|
||||||
if isinstance(generator, ErrorResponse):
|
|
||||||
return JSONResponse(
|
|
||||||
content=generator.model_dump(), status_code=generator.error.code
|
|
||||||
)
|
|
||||||
|
|
||||||
elif isinstance(generator, TranslationResponseVariant):
|
|
||||||
return JSONResponse(content=generator.model_dump())
|
|
||||||
|
|
||||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
|
||||||
|
|
||||||
|
|
||||||
def load_log_config(log_config_file: str | None) -> dict | None:
|
def load_log_config(log_config_file: str | None) -> dict | None:
|
||||||
if not log_config_file:
|
if not log_config_file:
|
||||||
return None
|
return None
|
||||||
@@ -741,6 +643,11 @@ def build_app(args: Namespace) -> FastAPI:
|
|||||||
)
|
)
|
||||||
|
|
||||||
register_responses_api_router(app)
|
register_responses_api_router(app)
|
||||||
|
from vllm.entrypoints.openai.translations.api_router import (
|
||||||
|
attach_router as register_translations_api_router,
|
||||||
|
)
|
||||||
|
|
||||||
|
register_translations_api_router(app)
|
||||||
from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes
|
from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes
|
||||||
|
|
||||||
register_sagemaker_routes(router)
|
register_sagemaker_routes(router)
|
||||||
|
|||||||
@@ -5,12 +5,10 @@
|
|||||||
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
|
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from http import HTTPStatus
|
|
||||||
from typing import Annotated, Any, ClassVar, Literal, TypeAlias
|
from typing import Annotated, Any, ClassVar, Literal, TypeAlias
|
||||||
|
|
||||||
import regex as re
|
import regex as re
|
||||||
import torch
|
import torch
|
||||||
from fastapi import HTTPException, UploadFile
|
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
BaseModel,
|
BaseModel,
|
||||||
ConfigDict,
|
ConfigDict,
|
||||||
@@ -702,522 +700,6 @@ class DeltaMessage(OpenAIBaseModel):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
class TranscriptionResponseStreamChoice(OpenAIBaseModel):
|
|
||||||
delta: DeltaMessage
|
|
||||||
finish_reason: str | None = None
|
|
||||||
stop_reason: int | str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class TranscriptionStreamResponse(OpenAIBaseModel):
|
|
||||||
id: str = Field(default_factory=lambda: f"trsc-{random_uuid()}")
|
|
||||||
object: Literal["transcription.chunk"] = "transcription.chunk"
|
|
||||||
created: int = Field(default_factory=lambda: int(time.time()))
|
|
||||||
model: str
|
|
||||||
choices: list[TranscriptionResponseStreamChoice]
|
|
||||||
usage: UsageInfo | None = Field(default=None)
|
|
||||||
|
|
||||||
|
|
||||||
## Protocols for Audio
|
|
||||||
AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", "vtt"]
|
|
||||||
|
|
||||||
|
|
||||||
class TranscriptionRequest(OpenAIBaseModel):
|
|
||||||
# Ordered by official OpenAI API documentation
|
|
||||||
# https://platform.openai.com/docs/api-reference/audio/createTranscription
|
|
||||||
|
|
||||||
file: UploadFile
|
|
||||||
"""
|
|
||||||
The audio file object (not file name) to transcribe, in one of these
|
|
||||||
formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
|
|
||||||
"""
|
|
||||||
|
|
||||||
model: str | None = None
|
|
||||||
"""ID of the model to use.
|
|
||||||
"""
|
|
||||||
|
|
||||||
language: str | None = None
|
|
||||||
"""The language of the input audio.
|
|
||||||
|
|
||||||
Supplying the input language in
|
|
||||||
[ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
|
|
||||||
will improve accuracy and latency.
|
|
||||||
"""
|
|
||||||
|
|
||||||
prompt: str = Field(default="")
|
|
||||||
"""An optional text to guide the model's style or continue a previous audio
|
|
||||||
segment.
|
|
||||||
|
|
||||||
The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
|
|
||||||
should match the audio language.
|
|
||||||
"""
|
|
||||||
|
|
||||||
response_format: AudioResponseFormat = Field(default="json")
|
|
||||||
"""
|
|
||||||
The format of the output, in one of these options: `json`, `text`, `srt`,
|
|
||||||
`verbose_json`, or `vtt`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
## TODO (varun) : Support if set to 0, certain thresholds are met !!
|
|
||||||
|
|
||||||
timestamp_granularities: list[Literal["word", "segment"]] = Field(
|
|
||||||
alias="timestamp_granularities[]", default=[]
|
|
||||||
)
|
|
||||||
"""The timestamp granularities to populate for this transcription.
|
|
||||||
|
|
||||||
`response_format` must be set `verbose_json` to use timestamp granularities.
|
|
||||||
Either or both of these options are supported: `word`, or `segment`. Note:
|
|
||||||
There is no additional latency for segment timestamps, but generating word
|
|
||||||
timestamps incurs additional latency.
|
|
||||||
"""
|
|
||||||
|
|
||||||
stream: bool | None = False
|
|
||||||
"""When set, it will enable output to be streamed in a similar fashion
|
|
||||||
as the Chat Completion endpoint.
|
|
||||||
"""
|
|
||||||
# --8<-- [start:transcription-extra-params]
|
|
||||||
# Flattened stream option to simplify form data.
|
|
||||||
stream_include_usage: bool | None = False
|
|
||||||
stream_continuous_usage_stats: bool | None = False
|
|
||||||
|
|
||||||
vllm_xargs: dict[str, str | int | float] | None = Field(
|
|
||||||
default=None,
|
|
||||||
description=(
|
|
||||||
"Additional request parameters with string or "
|
|
||||||
"numeric values, used by custom extensions."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
# --8<-- [end:transcription-extra-params]
|
|
||||||
|
|
||||||
to_language: str | None = None
|
|
||||||
"""The language of the output audio we transcribe to.
|
|
||||||
|
|
||||||
Please note that this is not currently used by supported models at this
|
|
||||||
time, but it is a placeholder for future use, matching translation api.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# --8<-- [start:transcription-sampling-params]
|
|
||||||
temperature: float = Field(default=0.0)
|
|
||||||
"""The sampling temperature, between 0 and 1.
|
|
||||||
|
|
||||||
Higher values like 0.8 will make the output more random, while lower values
|
|
||||||
like 0.2 will make it more focused / deterministic. If set to 0, the model
|
|
||||||
will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
|
|
||||||
to automatically increase the temperature until certain thresholds are hit.
|
|
||||||
"""
|
|
||||||
|
|
||||||
top_p: float | None = None
|
|
||||||
"""Enables nucleus (top-p) sampling, where tokens are selected from the
|
|
||||||
smallest possible set whose cumulative probability exceeds `p`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
top_k: int | None = None
|
|
||||||
"""Limits sampling to the `k` most probable tokens at each step."""
|
|
||||||
|
|
||||||
min_p: float | None = None
|
|
||||||
"""Filters out tokens with a probability lower than `min_p`, ensuring a
|
|
||||||
minimum likelihood threshold during sampling.
|
|
||||||
"""
|
|
||||||
|
|
||||||
seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
|
|
||||||
"""The seed to use for sampling."""
|
|
||||||
|
|
||||||
frequency_penalty: float | None = 0.0
|
|
||||||
"""The frequency penalty to use for sampling."""
|
|
||||||
|
|
||||||
repetition_penalty: float | None = None
|
|
||||||
"""The repetition penalty to use for sampling."""
|
|
||||||
|
|
||||||
presence_penalty: float | None = 0.0
|
|
||||||
"""The presence penalty to use for sampling."""
|
|
||||||
|
|
||||||
max_completion_tokens: int | None = None
|
|
||||||
"""The maximum number of tokens to generate."""
|
|
||||||
# --8<-- [end:transcription-sampling-params]
|
|
||||||
|
|
||||||
# Default sampling parameters for transcription requests.
|
|
||||||
_DEFAULT_SAMPLING_PARAMS: dict = {
|
|
||||||
"repetition_penalty": 1.0,
|
|
||||||
"temperature": 1.0,
|
|
||||||
"top_p": 1.0,
|
|
||||||
"top_k": 0,
|
|
||||||
"min_p": 0.0,
|
|
||||||
}
|
|
||||||
|
|
||||||
def to_sampling_params(
|
|
||||||
self, default_max_tokens: int, default_sampling_params: dict | None = None
|
|
||||||
) -> SamplingParams:
|
|
||||||
max_tokens = default_max_tokens
|
|
||||||
|
|
||||||
if default_sampling_params is None:
|
|
||||||
default_sampling_params = {}
|
|
||||||
|
|
||||||
# Default parameters
|
|
||||||
if (temperature := self.temperature) is None:
|
|
||||||
temperature = default_sampling_params.get(
|
|
||||||
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
|
|
||||||
)
|
|
||||||
if (top_p := self.top_p) is None:
|
|
||||||
top_p = default_sampling_params.get(
|
|
||||||
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
|
|
||||||
)
|
|
||||||
if (top_k := self.top_k) is None:
|
|
||||||
top_k = default_sampling_params.get(
|
|
||||||
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
|
|
||||||
)
|
|
||||||
if (min_p := self.min_p) is None:
|
|
||||||
min_p = default_sampling_params.get(
|
|
||||||
"min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
|
|
||||||
)
|
|
||||||
|
|
||||||
if (repetition_penalty := self.repetition_penalty) is None:
|
|
||||||
repetition_penalty = default_sampling_params.get(
|
|
||||||
"repetition_penalty",
|
|
||||||
self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
|
|
||||||
)
|
|
||||||
|
|
||||||
return SamplingParams.from_optional(
|
|
||||||
temperature=temperature,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
seed=self.seed,
|
|
||||||
top_p=top_p,
|
|
||||||
top_k=top_k,
|
|
||||||
min_p=min_p,
|
|
||||||
frequency_penalty=self.frequency_penalty,
|
|
||||||
repetition_penalty=repetition_penalty,
|
|
||||||
presence_penalty=self.presence_penalty,
|
|
||||||
output_kind=RequestOutputKind.DELTA
|
|
||||||
if self.stream
|
|
||||||
else RequestOutputKind.FINAL_ONLY,
|
|
||||||
extra_args=self.vllm_xargs,
|
|
||||||
skip_clone=True, # Created fresh per request, safe to skip clone
|
|
||||||
)
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def validate_transcription_request(cls, data):
|
|
||||||
if isinstance(data.get("file"), str):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
|
|
||||||
detail="Expected 'file' to be a file-like object, not 'str'.",
|
|
||||||
)
|
|
||||||
|
|
||||||
stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
|
|
||||||
stream = data.get("stream", False)
|
|
||||||
if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
|
|
||||||
# Find which specific stream option was set
|
|
||||||
invalid_param = next(
|
|
||||||
(so for so in stream_opts if data.get(so, False)),
|
|
||||||
"stream_include_usage",
|
|
||||||
)
|
|
||||||
raise VLLMValidationError(
|
|
||||||
"Stream options can only be defined when `stream=True`.",
|
|
||||||
parameter=invalid_param,
|
|
||||||
)
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
# Transcription response objects
|
|
||||||
class TranscriptionUsageAudio(OpenAIBaseModel):
|
|
||||||
type: Literal["duration"] = "duration"
|
|
||||||
seconds: int
|
|
||||||
|
|
||||||
|
|
||||||
class TranscriptionResponse(OpenAIBaseModel):
|
|
||||||
text: str
|
|
||||||
"""The transcribed text."""
|
|
||||||
usage: TranscriptionUsageAudio
|
|
||||||
|
|
||||||
|
|
||||||
class TranscriptionWord(OpenAIBaseModel):
|
|
||||||
end: float
|
|
||||||
"""End time of the word in seconds."""
|
|
||||||
|
|
||||||
start: float
|
|
||||||
"""Start time of the word in seconds."""
|
|
||||||
|
|
||||||
word: str
|
|
||||||
"""The text content of the word."""
|
|
||||||
|
|
||||||
|
|
||||||
class TranscriptionSegment(OpenAIBaseModel):
|
|
||||||
id: int
|
|
||||||
"""Unique identifier of the segment."""
|
|
||||||
|
|
||||||
avg_logprob: float | None = None
|
|
||||||
"""Average logprob of the segment.
|
|
||||||
|
|
||||||
If the value is lower than -1, consider the logprobs failed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
compression_ratio: float | None = None
|
|
||||||
"""Compression ratio of the segment.
|
|
||||||
|
|
||||||
If the value is greater than 2.4, consider the compression failed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
end: float
|
|
||||||
"""End time of the segment in seconds."""
|
|
||||||
|
|
||||||
no_speech_prob: float | None = None
|
|
||||||
"""Probability of no speech in the segment.
|
|
||||||
|
|
||||||
If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
|
|
||||||
this segment silent.
|
|
||||||
"""
|
|
||||||
|
|
||||||
seek: int
|
|
||||||
"""Seek offset of the segment."""
|
|
||||||
|
|
||||||
start: float
|
|
||||||
"""Start time of the segment in seconds."""
|
|
||||||
|
|
||||||
temperature: float
|
|
||||||
"""Temperature parameter used for generating the segment."""
|
|
||||||
|
|
||||||
text: str
|
|
||||||
"""Text content of the segment."""
|
|
||||||
|
|
||||||
tokens: list[int]
|
|
||||||
"""Array of token IDs for the text content."""
|
|
||||||
|
|
||||||
|
|
||||||
class TranscriptionResponseVerbose(OpenAIBaseModel):
|
|
||||||
duration: str
|
|
||||||
"""The duration of the input audio."""
|
|
||||||
|
|
||||||
language: str
|
|
||||||
"""The language of the input audio."""
|
|
||||||
|
|
||||||
text: str
|
|
||||||
"""The transcribed text."""
|
|
||||||
|
|
||||||
segments: list[TranscriptionSegment] | None = None
|
|
||||||
"""Segments of the transcribed text and their corresponding details."""
|
|
||||||
|
|
||||||
words: list[TranscriptionWord] | None = None
|
|
||||||
"""Extracted words and their corresponding timestamps."""
|
|
||||||
|
|
||||||
|
|
||||||
TranscriptionResponseVariant: TypeAlias = (
|
|
||||||
TranscriptionResponse | TranscriptionResponseVerbose
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TranslationResponseStreamChoice(OpenAIBaseModel):
|
|
||||||
delta: DeltaMessage
|
|
||||||
finish_reason: str | None = None
|
|
||||||
stop_reason: int | str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class TranslationStreamResponse(OpenAIBaseModel):
|
|
||||||
id: str = Field(default_factory=lambda: f"trsl-{random_uuid()}")
|
|
||||||
object: Literal["translation.chunk"] = "translation.chunk"
|
|
||||||
created: int = Field(default_factory=lambda: int(time.time()))
|
|
||||||
model: str
|
|
||||||
choices: list[TranslationResponseStreamChoice]
|
|
||||||
usage: UsageInfo | None = Field(default=None)
|
|
||||||
|
|
||||||
|
|
||||||
class TranslationRequest(OpenAIBaseModel):
|
|
||||||
# Ordered by official OpenAI API documentation
|
|
||||||
# https://platform.openai.com/docs/api-reference/audio/createTranslation
|
|
||||||
|
|
||||||
file: UploadFile
|
|
||||||
"""
|
|
||||||
The audio file object (not file name) to translate, in one of these
|
|
||||||
formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
|
|
||||||
"""
|
|
||||||
|
|
||||||
model: str | None = None
|
|
||||||
"""ID of the model to use.
|
|
||||||
"""
|
|
||||||
|
|
||||||
prompt: str = Field(default="")
|
|
||||||
"""An optional text to guide the model's style or continue a previous audio
|
|
||||||
segment.
|
|
||||||
|
|
||||||
The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
|
|
||||||
should match the audio language.
|
|
||||||
"""
|
|
||||||
|
|
||||||
response_format: AudioResponseFormat = Field(default="json")
|
|
||||||
"""
|
|
||||||
The format of the output, in one of these options: `json`, `text`, `srt`,
|
|
||||||
`verbose_json`, or `vtt`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# TODO support additional sampling parameters
|
|
||||||
# --8<-- [start:translation-sampling-params]
|
|
||||||
seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
|
|
||||||
"""The seed to use for sampling."""
|
|
||||||
|
|
||||||
temperature: float = Field(default=0.0)
|
|
||||||
"""The sampling temperature, between 0 and 1.
|
|
||||||
|
|
||||||
Higher values like 0.8 will make the output more random, while lower values
|
|
||||||
like 0.2 will make it more focused / deterministic. If set to 0, the model
|
|
||||||
will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
|
|
||||||
to automatically increase the temperature until certain thresholds are hit.
|
|
||||||
"""
|
|
||||||
# --8<-- [end:translation-sampling-params]
|
|
||||||
|
|
||||||
# --8<-- [start:translation-extra-params]
|
|
||||||
language: str | None = None
|
|
||||||
"""The language of the input audio we translate from.
|
|
||||||
|
|
||||||
Supplying the input language in
|
|
||||||
[ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
|
|
||||||
will improve accuracy.
|
|
||||||
"""
|
|
||||||
|
|
||||||
to_language: str | None = None
|
|
||||||
"""The language of the input audio we translate to.
|
|
||||||
|
|
||||||
Please note that this is not supported by all models, refer to the specific
|
|
||||||
model documentation for more details.
|
|
||||||
For instance, Whisper only supports `to_language=en`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
stream: bool | None = False
|
|
||||||
"""Custom field not present in the original OpenAI definition. When set,
|
|
||||||
it will enable output to be streamed in a similar fashion as the Chat
|
|
||||||
Completion endpoint.
|
|
||||||
"""
|
|
||||||
# Flattened stream option to simplify form data.
|
|
||||||
stream_include_usage: bool | None = False
|
|
||||||
stream_continuous_usage_stats: bool | None = False
|
|
||||||
|
|
||||||
max_completion_tokens: int | None = None
|
|
||||||
"""The maximum number of tokens to generate."""
|
|
||||||
# --8<-- [end:translation-extra-params]
|
|
||||||
|
|
||||||
# Default sampling parameters for translation requests.
|
|
||||||
_DEFAULT_SAMPLING_PARAMS: dict = {
|
|
||||||
"temperature": 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
def to_sampling_params(
|
|
||||||
self, default_max_tokens: int, default_sampling_params: dict | None = None
|
|
||||||
) -> SamplingParams:
|
|
||||||
max_tokens = default_max_tokens
|
|
||||||
|
|
||||||
if default_sampling_params is None:
|
|
||||||
default_sampling_params = {}
|
|
||||||
# Default parameters
|
|
||||||
if (temperature := self.temperature) is None:
|
|
||||||
temperature = default_sampling_params.get(
|
|
||||||
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
|
|
||||||
)
|
|
||||||
|
|
||||||
return SamplingParams.from_optional(
|
|
||||||
temperature=temperature,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
seed=self.seed,
|
|
||||||
output_kind=RequestOutputKind.DELTA
|
|
||||||
if self.stream
|
|
||||||
else RequestOutputKind.FINAL_ONLY,
|
|
||||||
skip_clone=True, # Created fresh per request, safe to skip clone
|
|
||||||
)
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def validate_stream_options(cls, data):
|
|
||||||
stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
|
|
||||||
stream = data.get("stream", False)
|
|
||||||
if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
|
|
||||||
# Find which specific stream option was set
|
|
||||||
invalid_param = next(
|
|
||||||
(so for so in stream_opts if data.get(so, False)),
|
|
||||||
"stream_include_usage",
|
|
||||||
)
|
|
||||||
raise VLLMValidationError(
|
|
||||||
"Stream options can only be defined when `stream=True`.",
|
|
||||||
parameter=invalid_param,
|
|
||||||
)
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
# Translation response objects
|
|
||||||
class TranslationResponse(OpenAIBaseModel):
|
|
||||||
text: str
|
|
||||||
"""The translated text."""
|
|
||||||
|
|
||||||
|
|
||||||
class TranslationWord(OpenAIBaseModel):
|
|
||||||
end: float
|
|
||||||
"""End time of the word in seconds."""
|
|
||||||
|
|
||||||
start: float
|
|
||||||
"""Start time of the word in seconds."""
|
|
||||||
|
|
||||||
word: str
|
|
||||||
"""The text content of the word."""
|
|
||||||
|
|
||||||
|
|
||||||
class TranslationSegment(OpenAIBaseModel):
|
|
||||||
id: int
|
|
||||||
"""Unique identifier of the segment."""
|
|
||||||
|
|
||||||
avg_logprob: float | None = None
|
|
||||||
"""Average logprob of the segment.
|
|
||||||
|
|
||||||
If the value is lower than -1, consider the logprobs failed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
compression_ratio: float | None = None
|
|
||||||
"""Compression ratio of the segment.
|
|
||||||
|
|
||||||
If the value is greater than 2.4, consider the compression failed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
end: float
|
|
||||||
"""End time of the segment in seconds."""
|
|
||||||
|
|
||||||
no_speech_prob: float | None = None
|
|
||||||
"""Probability of no speech in the segment.
|
|
||||||
|
|
||||||
If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
|
|
||||||
this segment silent.
|
|
||||||
"""
|
|
||||||
|
|
||||||
seek: int
|
|
||||||
"""Seek offset of the segment."""
|
|
||||||
|
|
||||||
start: float
|
|
||||||
"""Start time of the segment in seconds."""
|
|
||||||
|
|
||||||
temperature: float
|
|
||||||
"""Temperature parameter used for generating the segment."""
|
|
||||||
|
|
||||||
text: str
|
|
||||||
"""Text content of the segment."""
|
|
||||||
|
|
||||||
tokens: list[int]
|
|
||||||
"""Array of token IDs for the text content."""
|
|
||||||
|
|
||||||
|
|
||||||
class TranslationResponseVerbose(OpenAIBaseModel):
|
|
||||||
duration: str
|
|
||||||
"""The duration of the input audio."""
|
|
||||||
|
|
||||||
language: str
|
|
||||||
"""The language of the input audio."""
|
|
||||||
|
|
||||||
text: str
|
|
||||||
"""The translated text."""
|
|
||||||
|
|
||||||
segments: list[TranslationSegment] | None = None
|
|
||||||
"""Segments of the translated text and their corresponding details."""
|
|
||||||
|
|
||||||
words: list[TranslationWord] | None = None
|
|
||||||
"""Extracted words and their corresponding timestamps."""
|
|
||||||
|
|
||||||
|
|
||||||
TranslationResponseVariant: TypeAlias = TranslationResponse | TranslationResponseVerbose
|
|
||||||
|
|
||||||
|
|
||||||
####### Tokens IN <> Tokens OUT #######
|
####### Tokens IN <> Tokens OUT #######
|
||||||
class GenerateRequest(BaseModel):
|
class GenerateRequest(BaseModel):
|
||||||
request_id: str = Field(
|
request_id: str = Field(
|
||||||
|
|||||||
@@ -50,9 +50,6 @@ from vllm.entrypoints.openai.engine.protocol import (
|
|||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
FunctionCall,
|
FunctionCall,
|
||||||
FunctionDefinition,
|
FunctionDefinition,
|
||||||
TranscriptionRequest,
|
|
||||||
TranscriptionResponse,
|
|
||||||
TranslationRequest,
|
|
||||||
VLLMValidationError,
|
VLLMValidationError,
|
||||||
)
|
)
|
||||||
from vllm.entrypoints.openai.responses.protocol import (
|
from vllm.entrypoints.openai.responses.protocol import (
|
||||||
@@ -60,6 +57,11 @@ from vllm.entrypoints.openai.responses.protocol import (
|
|||||||
ResponsesRequest,
|
ResponsesRequest,
|
||||||
)
|
)
|
||||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||||
|
from vllm.entrypoints.openai.translations.protocol import (
|
||||||
|
TranscriptionRequest,
|
||||||
|
TranscriptionResponse,
|
||||||
|
TranslationRequest,
|
||||||
|
)
|
||||||
from vllm.entrypoints.pooling.classify.protocol import (
|
from vllm.entrypoints.pooling.classify.protocol import (
|
||||||
ClassificationChatRequest,
|
ClassificationChatRequest,
|
||||||
ClassificationCompletionRequest,
|
ClassificationCompletionRequest,
|
||||||
|
|||||||
2
vllm/entrypoints/openai/translations/__init__.py
Normal file
2
vllm/entrypoints/openai/translations/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
122
vllm/entrypoints/openai/translations/api_router.py
Normal file
122
vllm/entrypoints/openai/translations/api_router.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
|
||||||
|
from http import HTTPStatus
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, FastAPI, Form, HTTPException, Request
|
||||||
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
|
|
||||||
|
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||||
|
from vllm.entrypoints.openai.translations.protocol import (
|
||||||
|
TranscriptionRequest,
|
||||||
|
TranscriptionResponseVariant,
|
||||||
|
TranslationRequest,
|
||||||
|
TranslationResponseVariant,
|
||||||
|
)
|
||||||
|
from vllm.entrypoints.openai.translations.serving import (
|
||||||
|
OpenAIServingTranscription,
|
||||||
|
OpenAIServingTranslation,
|
||||||
|
)
|
||||||
|
from vllm.entrypoints.utils import (
|
||||||
|
load_aware_call,
|
||||||
|
with_cancellation,
|
||||||
|
)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)

# Router holding the audio transcription/translation routes declared below;
# it is mounted on an application by ``attach_router``.
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
def transcription(request: Request) -> OpenAIServingTranscription | None:
    """Return the transcription handler stored on the application state.

    Args:
        request: Incoming request; only ``request.app.state`` is read.

    Returns:
        The object stored as ``app.state.openai_serving_transcription``, or
        ``None`` when that attribute is ``None`` or was never assigned.
    """
    # A missing attribute is treated like an explicit ``None`` so the route
    # handler can answer with its "not supported" error instead of failing
    # with an AttributeError.
    return getattr(request.app.state, "openai_serving_transcription", None)
|
||||||
|
|
||||||
|
|
||||||
|
def translation(request: Request) -> OpenAIServingTranslation | None:
    """Return the translation handler stored on the application state.

    Args:
        request: Incoming request; only ``request.app.state`` is read.

    Returns:
        The object stored as ``app.state.openai_serving_translation``, or
        ``None`` when that attribute is ``None`` or was never assigned.
    """
    # A missing attribute is treated like an explicit ``None`` so the route
    # handler can answer with its "not supported" error instead of failing
    # with an AttributeError.
    return getattr(request.app.state, "openai_serving_translation", None)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/v1/audio/transcriptions",
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_transcriptions(
    raw_request: Request, request: Annotated[TranscriptionRequest, Form()]
):
    """Handle ``POST /v1/audio/transcriptions``.

    Args:
        raw_request: The underlying HTTP request, used to reach the app state.
        request: The transcription form payload, including the uploaded file.

    Returns:
        A ``JSONResponse`` carrying either an error (with the error's own
        status code) or the finished transcription, or a ``StreamingResponse``
        with media type ``text/event-stream`` for any other handler result.

    Raises:
        HTTPException: With status 500 when the handler raises any exception.
    """
    handler = transcription(raw_request)
    if handler is None:
        base_server = raw_request.app.state.openai_serving_tokenization
        error = base_server.create_error_response(
            message="The model does not support Transcriptions API"
        )
        # Send the error with its own status code, exactly like the
        # ErrorResponse branch below; returning the object directly would be
        # serialized by FastAPI with HTTP 200.
        # NOTE(review): assumes create_error_response returns an ErrorResponse.
        return JSONResponse(content=error.model_dump(), status_code=error.error.code)

    audio_data = await request.file.read()
    try:
        generator = await handler.create_transcription(audio_data, request, raw_request)
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )

    if isinstance(generator, TranscriptionResponseVariant):
        return JSONResponse(content=generator.model_dump())

    # Anything else is treated as a stream of chunks.
    return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/v1/audio/translations",
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_translations(
    request: Annotated[TranslationRequest, Form()], raw_request: Request
):
    """Handle ``POST /v1/audio/translations``.

    Args:
        request: The translation form payload, including the uploaded file.
        raw_request: The underlying HTTP request, used to reach the app state.

    Returns:
        A ``JSONResponse`` carrying either an error (with the error's own
        status code) or the finished translation, or a ``StreamingResponse``
        with media type ``text/event-stream`` for any other handler result.

    Raises:
        HTTPException: With status 500 when the handler raises any exception.
    """
    handler = translation(raw_request)
    if handler is None:
        base_server = raw_request.app.state.openai_serving_tokenization
        error = base_server.create_error_response(
            message="The model does not support Translations API"
        )
        # Send the error with its own status code, exactly like the
        # ErrorResponse branch below; returning the object directly would be
        # serialized by FastAPI with HTTP 200.
        # NOTE(review): assumes create_error_response returns an ErrorResponse.
        return JSONResponse(content=error.model_dump(), status_code=error.error.code)

    audio_data = await request.file.read()
    try:
        generator = await handler.create_translation(audio_data, request, raw_request)
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )

    if isinstance(generator, TranslationResponseVariant):
        return JSONResponse(content=generator.model_dump())

    # Anything else is treated as a stream of chunks.
    return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||||
|
|
||||||
|
|
||||||
|
def attach_router(app: FastAPI):
    """Mount this module's ``router`` (the audio routes) on ``app``."""
    app.include_router(router)
|
||||||
545
vllm/entrypoints/openai/translations/protocol.py
Normal file
545
vllm/entrypoints/openai/translations/protocol.py
Normal file
@@ -0,0 +1,545 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import time
|
||||||
|
from http import HTTPStatus
|
||||||
|
from typing import Literal, TypeAlias
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from fastapi import HTTPException, UploadFile
|
||||||
|
from pydantic import (
|
||||||
|
Field,
|
||||||
|
model_validator,
|
||||||
|
)
|
||||||
|
|
||||||
|
from vllm.entrypoints.openai.engine.protocol import (
|
||||||
|
DeltaMessage,
|
||||||
|
OpenAIBaseModel,
|
||||||
|
UsageInfo,
|
||||||
|
)
|
||||||
|
from vllm.exceptions import VLLMValidationError
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.sampling_params import (
|
||||||
|
RequestOutputKind,
|
||||||
|
SamplingParams,
|
||||||
|
)
|
||||||
|
from vllm.utils import random_uuid
|
||||||
|
|
||||||
|
logger = init_logger(__name__)

# Integer limits of ``torch.long``; they bound the ``seed`` request fields.
_LONG_INFO = torch.iinfo(torch.long)
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptionResponseStreamChoice(OpenAIBaseModel):
    """Single entry of the ``choices`` list in a transcription stream chunk."""

    delta: DeltaMessage
    finish_reason: str | None = None
    stop_reason: int | str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptionStreamResponse(OpenAIBaseModel):
    """Chunk object (``transcription.chunk``) for streamed transcriptions.

    ``id`` defaults to a fresh ``trsc-``-prefixed identifier and ``created``
    to the current Unix time in whole seconds.
    """

    id: str = Field(default_factory=lambda: f"trsc-{random_uuid()}")
    object: Literal["transcription.chunk"] = "transcription.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[TranscriptionResponseStreamChoice]
    usage: UsageInfo | None = None
|
||||||
|
|
||||||
|
|
||||||
|
## Protocols for Audio
|
||||||
|
AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", "vtt"]
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptionRequest(OpenAIBaseModel):
    """Request payload for audio transcription.

    Besides the OpenAI-documented fields it carries vLLM-specific extras
    (flattened stream options, ``vllm_xargs``) and additional sampling
    parameters, and it can turn itself into ``SamplingParams``.
    """

    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/audio/createTranscription

    file: UploadFile
    """
    The audio file object (not file name) to transcribe, in one of these
    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
    """

    model: str | None = None
    """ID of the model to use.
    """

    language: str | None = None
    """The language of the input audio.

    Supplying the input language in
    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
    will improve accuracy and latency.
    """

    prompt: str = Field(default="")
    """An optional text to guide the model's style or continue a previous audio
    segment.

    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
    should match the audio language.
    """

    response_format: AudioResponseFormat = Field(default="json")
    """
    The format of the output, in one of these options: `json`, `text`, `srt`,
    `verbose_json`, or `vtt`.
    """

    ## TODO (varun) : Support if set to 0, certain thresholds are met !!

    timestamp_granularities: list[Literal["word", "segment"]] = Field(
        default=[], alias="timestamp_granularities[]"
    )
    """The timestamp granularities to populate for this transcription.

    `response_format` must be set `verbose_json` to use timestamp granularities.
    Either or both of these options are supported: `word`, or `segment`. Note:
    There is no additional latency for segment timestamps, but generating word
    timestamps incurs additional latency.
    """

    stream: bool | None = False
    """When set, it will enable output to be streamed in a similar fashion
    as the Chat Completion endpoint.
    """
    # --8<-- [start:transcription-extra-params]
    # Flattened stream option to simplify form data.
    stream_include_usage: bool | None = False
    stream_continuous_usage_stats: bool | None = False

    vllm_xargs: dict[str, str | int | float] | None = Field(
        default=None,
        description=(
            "Additional request parameters with string or "
            "numeric values, used by custom extensions."
        ),
    )
    # --8<-- [end:transcription-extra-params]

    to_language: str | None = None
    """The language of the output audio we transcribe to.

    Please note that this is not currently used by supported models at this
    time, but it is a placeholder for future use, matching translation api.
    """

    # --8<-- [start:transcription-sampling-params]
    temperature: float = Field(default=0.0)
    """The sampling temperature, between 0 and 1.

    Higher values like 0.8 will make the output more random, while lower values
    like 0.2 will make it more focused / deterministic. If set to 0, the model
    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """

    top_p: float | None = None
    """Enables nucleus (top-p) sampling, where tokens are selected from the
    smallest possible set whose cumulative probability exceeds `p`.
    """

    top_k: int | None = None
    """Limits sampling to the `k` most probable tokens at each step."""

    min_p: float | None = None
    """Filters out tokens with a probability lower than `min_p`, ensuring a
    minimum likelihood threshold during sampling.
    """

    seed: int | None = Field(default=None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    """The seed to use for sampling."""

    frequency_penalty: float | None = 0.0
    """The frequency penalty to use for sampling."""

    repetition_penalty: float | None = None
    """The repetition penalty to use for sampling."""

    presence_penalty: float | None = 0.0
    """The presence penalty to use for sampling."""

    max_completion_tokens: int | None = None
    """The maximum number of tokens to generate."""
    # --8<-- [end:transcription-sampling-params]

    # Default sampling parameters for transcription requests.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
        "min_p": 0.0,
    }

    def to_sampling_params(
        self, default_max_tokens: int, default_sampling_params: dict | None = None
    ) -> SamplingParams:
        """Build the ``SamplingParams`` for this request.

        Args:
            default_max_tokens: Passed through unchanged as ``max_tokens``.
            default_sampling_params: Optional mapping consulted for
                ``temperature``, ``top_p``, ``top_k``, ``min_p`` and
                ``repetition_penalty`` when the request leaves them ``None``;
                keys it lacks fall back to ``_DEFAULT_SAMPLING_PARAMS``.

        Returns:
            Sampling parameters whose output kind is ``DELTA`` when ``stream``
            is truthy and ``FINAL_ONLY`` otherwise.
        """
        fallbacks = {} if default_sampling_params is None else default_sampling_params

        def resolve(name: str):
            # Request value first, then the caller's defaults, then built-ins.
            requested = getattr(self, name)
            if requested is not None:
                return requested
            return fallbacks.get(name, self._DEFAULT_SAMPLING_PARAMS[name])

        # Default parameters
        temperature = resolve("temperature")
        top_p = resolve("top_p")
        top_k = resolve("top_k")
        min_p = resolve("min_p")
        repetition_penalty = resolve("repetition_penalty")

        if self.stream:
            output_kind = RequestOutputKind.DELTA
        else:
            output_kind = RequestOutputKind.FINAL_ONLY

        return SamplingParams.from_optional(
            temperature=temperature,
            max_tokens=default_max_tokens,
            seed=self.seed,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            presence_penalty=self.presence_penalty,
            output_kind=output_kind,
            extra_args=self.vllm_xargs,
            skip_clone=True,  # Created fresh per request, safe to skip clone
        )

    @model_validator(mode="before")
    @classmethod
    def validate_transcription_request(cls, data):
        """Check the raw payload before field validation runs.

        Raises:
            HTTPException: With status 422 when ``file`` is a plain string.
            VLLMValidationError: When a stream option is set while ``stream``
                is falsy; ``parameter`` names the first offending option.
        """
        if isinstance(data.get("file"), str):
            raise HTTPException(
                status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
                detail="Expected 'file' to be a file-like object, not 'str'.",
            )

        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
        if not data.get("stream", False):
            # Collect the stream options that were set despite stream being off.
            enabled = [opt for opt in stream_opts if data.get(opt, False)]
            if enabled:
                raise VLLMValidationError(
                    "Stream options can only be defined when `stream=True`.",
                    parameter=enabled[0],
                )

        return data
|
||||||
|
|
||||||
|
|
||||||
|
# Transcription response objects
|
||||||
|
class TranscriptionUsageAudio(OpenAIBaseModel):
    """Usage record of type ``duration`` holding a whole number of seconds."""

    type: Literal["duration"] = "duration"
    seconds: int
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptionResponse(OpenAIBaseModel):
    """Plain transcription result: the text plus its audio usage record."""

    text: str
    """The transcribed text."""

    usage: TranscriptionUsageAudio
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptionWord(OpenAIBaseModel):
    """A single transcribed word together with its start and end time."""

    end: float
    """End time of the word in seconds."""

    start: float
    """Start time of the word in seconds."""

    word: str
    """The text content of the word."""
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptionSegment(OpenAIBaseModel):
    """One segment of a transcription: timing, text, tokens and statistics."""

    id: int
    """Unique identifier of the segment."""

    avg_logprob: float | None = None
    """Average logprob of the segment.

    If the value is lower than -1, consider the logprobs failed.
    """

    compression_ratio: float | None = None
    """Compression ratio of the segment.

    If the value is greater than 2.4, consider the compression failed.
    """

    end: float
    """End time of the segment in seconds."""

    no_speech_prob: float | None = None
    """Probability of no speech in the segment.

    If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
    this segment silent.
    """

    seek: int
    """Seek offset of the segment."""

    start: float
    """Start time of the segment in seconds."""

    temperature: float
    """Temperature parameter used for generating the segment."""

    text: str
    """Text content of the segment."""

    tokens: list[int]
    """Array of token IDs for the text content."""
|
||||||
|
|
||||||
|
|
||||||
|
class TranscriptionResponseVerbose(OpenAIBaseModel):
    """Detailed transcription result with optional segments and words."""

    duration: str
    """The duration of the input audio."""

    language: str
    """The language of the input audio."""

    text: str
    """The transcribed text."""

    segments: list[TranscriptionSegment] | None = None
    """Segments of the transcribed text and their corresponding details."""

    words: list[TranscriptionWord] | None = None
    """Extracted words and their corresponding timestamps."""
|
||||||
|
|
||||||
|
|
||||||
|
# Either of the two non-streaming transcription result shapes.
TranscriptionResponseVariant: TypeAlias = (
    TranscriptionResponse | TranscriptionResponseVerbose
)
|
||||||
|
|
||||||
|
|
||||||
|
class TranslationResponseStreamChoice(OpenAIBaseModel):
    """Single entry of the ``choices`` list in a translation stream chunk."""

    delta: DeltaMessage
    finish_reason: str | None = None
    stop_reason: int | str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TranslationStreamResponse(OpenAIBaseModel):
    """Chunk object (``translation.chunk``) for streamed translations.

    ``id`` defaults to a fresh ``trsl-``-prefixed identifier and ``created``
    to the current Unix time in whole seconds.
    """

    id: str = Field(default_factory=lambda: f"trsl-{random_uuid()}")
    object: Literal["translation.chunk"] = "translation.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[TranslationResponseStreamChoice]
    usage: UsageInfo | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TranslationRequest(OpenAIBaseModel):
    """Request payload for audio translation.

    Besides the OpenAI-documented fields it carries vLLM-specific extras
    (source/target language, flattened stream options) and can turn itself
    into ``SamplingParams``.
    """

    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/audio/createTranslation

    file: UploadFile
    """
    The audio file object (not file name) to translate, in one of these
    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
    """

    model: str | None = None
    """ID of the model to use.
    """

    prompt: str = Field(default="")
    """An optional text to guide the model's style or continue a previous audio
    segment.

    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
    should match the audio language.
    """

    response_format: AudioResponseFormat = Field(default="json")
    """
    The format of the output, in one of these options: `json`, `text`, `srt`,
    `verbose_json`, or `vtt`.
    """

    # TODO support additional sampling parameters
    # --8<-- [start:translation-sampling-params]
    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    """The seed to use for sampling."""

    temperature: float = Field(default=0.0)
    """The sampling temperature, between 0 and 1.

    Higher values like 0.8 will make the output more random, while lower values
    like 0.2 will make it more focused / deterministic. If set to 0, the model
    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """
    # --8<-- [end:translation-sampling-params]

    # --8<-- [start:translation-extra-params]
    language: str | None = None
    """The language of the input audio we translate from.

    Supplying the input language in
    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
    will improve accuracy.
    """

    to_language: str | None = None
    """The language of the output text we translate to.

    Please note that this is not supported by all models, refer to the specific
    model documentation for more details.
    For instance, Whisper only supports `to_language=en`.
    """

    stream: bool | None = False
    """Custom field not present in the original OpenAI definition. When set,
    it will enable output to be streamed in a similar fashion as the Chat
    Completion endpoint.
    """
    # Flattened stream option to simplify form data.
    stream_include_usage: bool | None = False
    stream_continuous_usage_stats: bool | None = False

    max_completion_tokens: int | None = None
    """The maximum number of tokens to generate."""
    # --8<-- [end:translation-extra-params]

    # Default sampling parameters for translation requests.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "temperature": 0,
    }

    def to_sampling_params(
        self, default_max_tokens: int, default_sampling_params: dict | None = None
    ) -> SamplingParams:
        """Build the ``SamplingParams`` for this request.

        Args:
            default_max_tokens: Passed through unchanged as ``max_tokens``.
            default_sampling_params: Optional mapping consulted for
                ``temperature`` when the request leaves it ``None``; if the
                key is absent ``_DEFAULT_SAMPLING_PARAMS`` is used.

        Returns:
            Sampling parameters whose output kind is ``DELTA`` when ``stream``
            is truthy and ``FINAL_ONLY`` otherwise.
        """
        max_tokens = default_max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}
        # Default parameters
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )

        return SamplingParams.from_optional(
            temperature=temperature,
            max_tokens=max_tokens,
            seed=self.seed,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            skip_clone=True,  # Created fresh per request, safe to skip clone
        )

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject stream options supplied without ``stream`` being enabled.

        Raises:
            VLLMValidationError: When a stream option is set while ``stream``
                is falsy; ``parameter`` names the first offending option.
        """
        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
        stream = data.get("stream", False)
        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
            # Find which specific stream option was set
            invalid_param = next(
                (so for so in stream_opts if data.get(so, False)),
                "stream_include_usage",
            )
            raise VLLMValidationError(
                "Stream options can only be defined when `stream=True`.",
                parameter=invalid_param,
            )

        return data
|
||||||
|
|
||||||
|
|
||||||
|
# Translation response objects
|
||||||
|
class TranslationResponse(OpenAIBaseModel):
    """Plain translation result containing only the translated text."""

    text: str
    """The translated text."""
|
||||||
|
|
||||||
|
|
||||||
|
class TranslationWord(OpenAIBaseModel):
    """A single translated word together with its start and end time."""

    end: float
    """End time of the word in seconds."""

    start: float
    """Start time of the word in seconds."""

    word: str
    """The text content of the word."""
|
||||||
|
|
||||||
|
|
||||||
|
class TranslationSegment(OpenAIBaseModel):
    """One segment of a translation: timing, text, tokens and statistics."""

    id: int
    """Unique identifier of the segment."""

    avg_logprob: float | None = None
    """Average logprob of the segment.

    If the value is lower than -1, consider the logprobs failed.
    """

    compression_ratio: float | None = None
    """Compression ratio of the segment.

    If the value is greater than 2.4, consider the compression failed.
    """

    end: float
    """End time of the segment in seconds."""

    no_speech_prob: float | None = None
    """Probability of no speech in the segment.

    If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
    this segment silent.
    """

    seek: int
    """Seek offset of the segment."""

    start: float
    """Start time of the segment in seconds."""

    temperature: float
    """Temperature parameter used for generating the segment."""

    text: str
    """Text content of the segment."""

    tokens: list[int]
    """Array of token IDs for the text content."""
|
||||||
|
|
||||||
|
|
||||||
|
class TranslationResponseVerbose(OpenAIBaseModel):
    """Detailed translation result with optional segments and words."""

    duration: str
    """The duration of the input audio."""

    language: str
    """The language of the input audio."""

    text: str
    """The translated text."""

    segments: list[TranslationSegment] | None = None
    """Segments of the translated text and their corresponding details."""

    words: list[TranslationWord] | None = None
    """Extracted words and their corresponding timestamps."""
|
||||||
|
|
||||||
|
|
||||||
|
# Either of the two non-streaming translation result shapes.
TranslationResponseVariant: TypeAlias = (
    TranslationResponse | TranslationResponseVerbose
)
|
||||||
@@ -9,6 +9,9 @@ from vllm.entrypoints.logger import RequestLogger
|
|||||||
from vllm.entrypoints.openai.engine.protocol import (
|
from vllm.entrypoints.openai.engine.protocol import (
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
RequestResponseMetadata,
|
RequestResponseMetadata,
|
||||||
|
)
|
||||||
|
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||||
|
from vllm.entrypoints.openai.translations.protocol import (
|
||||||
TranscriptionRequest,
|
TranscriptionRequest,
|
||||||
TranscriptionResponse,
|
TranscriptionResponse,
|
||||||
TranscriptionResponseStreamChoice,
|
TranscriptionResponseStreamChoice,
|
||||||
@@ -20,8 +23,7 @@ from vllm.entrypoints.openai.engine.protocol import (
|
|||||||
TranslationResponseVerbose,
|
TranslationResponseVerbose,
|
||||||
TranslationStreamResponse,
|
TranslationStreamResponse,
|
||||||
)
|
)
|
||||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
from vllm.entrypoints.openai.translations.speech_to_text import OpenAISpeechToText
|
||||||
from vllm.entrypoints.openai.speech_to_text import OpenAISpeechToText
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.outputs import RequestOutput
|
from vllm.outputs import RequestOutput
|
||||||
|
|
||||||
@@ -19,6 +19,11 @@ from vllm.entrypoints.openai.engine.protocol import (
|
|||||||
DeltaMessage,
|
DeltaMessage,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
RequestResponseMetadata,
|
RequestResponseMetadata,
|
||||||
|
UsageInfo,
|
||||||
|
)
|
||||||
|
from vllm.entrypoints.openai.engine.serving import OpenAIServing, SpeechToTextRequest
|
||||||
|
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||||
|
from vllm.entrypoints.openai.translations.protocol import (
|
||||||
TranscriptionResponse,
|
TranscriptionResponse,
|
||||||
TranscriptionResponseStreamChoice,
|
TranscriptionResponseStreamChoice,
|
||||||
TranscriptionResponseVerbose,
|
TranscriptionResponseVerbose,
|
||||||
@@ -29,11 +34,8 @@ from vllm.entrypoints.openai.engine.protocol import (
|
|||||||
TranslationResponseVerbose,
|
TranslationResponseVerbose,
|
||||||
TranslationSegment,
|
TranslationSegment,
|
||||||
TranslationStreamResponse,
|
TranslationStreamResponse,
|
||||||
UsageInfo,
|
|
||||||
VLLMValidationError,
|
|
||||||
)
|
)
|
||||||
from vllm.entrypoints.openai.engine.serving import OpenAIServing, SpeechToTextRequest
|
from vllm.exceptions import VLLMValidationError
|
||||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
|
||||||
from vllm.inputs.data import PromptType
|
from vllm.inputs.data import PromptType
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.models import SupportsTranscription, supports_transcription
|
from vllm.model_executor.models import SupportsTranscription, supports_transcription
|
||||||
Reference in New Issue
Block a user