[Frontend] Add Support for MM Encoder/Decoder Beam Search (Online Transcriptions) (#36160)
Signed-off-by: Alex Brooks <albrooks@redhat.com>
This commit is contained in:
@@ -439,6 +439,8 @@ you can use the [official OpenAI Python client](https://github.com/openai/openai
|
||||
|
||||
Code example: [examples/online_serving/openai_transcription_client.py](../../examples/online_serving/openai_transcription_client.py)
|
||||
|
||||
NOTE: Beam search is currently supported in the transcriptions endpoint for encoder-decoder multimodal models (e.g., Whisper), but it is highly inefficient because handling of the encoder/decoder cache is still being implemented. This is being actively optimized and will be handled properly in the very near future.
|
||||
|
||||
#### API Enforced Limits
|
||||
|
||||
Set the maximum audio file size (in MB) that VLLM will accept, via the
|
||||
|
||||
@@ -317,3 +317,72 @@ async def test_language_auto_detect(
|
||||
assert any(word.lower() in text_lower for word in expected_text), (
|
||||
f"Expected {expected_lang} text but got: {transcription.text}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_whisper_beam_search_single_beam(mary_had_lamb, whisper_client):
    """Check that beam search with a single beam matches greedy decoding.

    Both requests are sent to the transcriptions endpoint with the same model,
    audio file, language, response format and temperature. Only the decoding
    strategy differs, so a mismatch can be attributed to beam search itself
    rather than to a difference in request arguments.
    """
    # Arguments shared by both requests; keeping them in one place guarantees
    # the greedy baseline is not silently run with a different language.
    shared_kwargs = dict(
        model=MODEL_NAME,
        file=mary_had_lamb,
        language="en",
        response_format="text",
        temperature=0.0,
    )

    beam_transcription = await whisper_client.audio.transcriptions.create(
        **shared_kwargs,
        extra_body=dict(
            use_beam_search=True,
            n=1,
        ),
    )

    greedy_transcription = await whisper_client.audio.transcriptions.create(
        **shared_kwargs,
    )

    # The responses are decoded as JSON payloads carrying a "text" field.
    greedy_res = json.loads(greedy_transcription)["text"]
    beam_res = json.loads(beam_transcription)["text"]
    assert greedy_res == beam_res
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_whisper_beam_search_multibeam(mary_had_lamb, whisper_client):
    """Beam search with n>1 must still yield a single transcription.

    The request asks for two beams; the response is expected to carry one
    non-empty "text" field containing the expected lyric.
    """
    beam_options = {"use_beam_search": True, "n": 2}
    raw_response = await whisper_client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=mary_had_lamb,
        language="en",
        response_format="text",
        temperature=0.0,
        extra_body=beam_options,
    )

    # The response is decoded as a JSON payload carrying a "text" field.
    text = json.loads(raw_response)["text"]

    assert text is not None
    assert len(text) > 0
    assert "mary had a little lamb" in text.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_with_beams_raises(winning_call, whisper_client):
    """Combining stream=True with beam search must be rejected as a bad request."""
    beam_options = {"use_beam_search": True, "n": 2}

    with pytest.raises(openai.BadRequestError):
        await whisper_client.audio.transcriptions.create(
            model=MODEL_NAME,
            file=winning_call,
            language="en",
            stream=True,
            extra_body=beam_options,
        )
|
||||
|
||||
@@ -129,6 +129,11 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
- suffix (the language models we currently support do not support
|
||||
suffix)
|
||||
"""
|
||||
if request.stream and request.use_beam_search:
|
||||
return self.create_error_response(
|
||||
"Streaming is not currently supported with beam search"
|
||||
)
|
||||
|
||||
result = await self.render_completion_request(request)
|
||||
if isinstance(result, ErrorResponse):
|
||||
return result
|
||||
@@ -211,13 +216,10 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
model_name = self.models.model_name(lora_request)
|
||||
num_prompts = len(engine_prompts)
|
||||
|
||||
# We do not stream the results when using beam search.
|
||||
stream = request.stream and not request.use_beam_search
|
||||
|
||||
# Streaming response
|
||||
tokenizer = self.renderer.tokenizer
|
||||
|
||||
if stream:
|
||||
if request.stream:
|
||||
return self.completion_stream_generator(
|
||||
request,
|
||||
engine_prompts,
|
||||
|
||||
@@ -237,13 +237,14 @@ class OpenAIServing:
|
||||
|
||||
if prompt["type"] == "embeds":
|
||||
raise NotImplementedError("Embedding prompt not supported for beam search")
|
||||
if prompt["type"] == "enc_dec":
|
||||
raise NotImplementedError(
|
||||
"Encoder-decoder prompt not supported for beam search"
|
||||
)
|
||||
|
||||
prompt_text = prompt.get("prompt")
|
||||
prompt_token_ids = prompt["prompt_token_ids"]
|
||||
# Extract prompt tokens and text based on model type
|
||||
decoder_prompt = (
|
||||
prompt if prompt["type"] != "enc_dec" else prompt["decoder_prompt"]
|
||||
)
|
||||
prompt_text = decoder_prompt.get("prompt")
|
||||
prompt_token_ids = decoder_prompt["prompt_token_ids"]
|
||||
|
||||
tokenized_length = len(prompt_token_ids)
|
||||
|
||||
logprobs_num = 2 * beam_width
|
||||
|
||||
@@ -20,6 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import (
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import (
|
||||
BeamSearchParams,
|
||||
RequestOutputKind,
|
||||
SamplingParams,
|
||||
)
|
||||
@@ -123,6 +124,18 @@ class TranscriptionRequest(OpenAIBaseModel):
|
||||
"""
|
||||
|
||||
# --8<-- [start:transcription-sampling-params]
|
||||
use_beam_search: bool = False
|
||||
"""Whether or not beam search should be used."""
|
||||
|
||||
n: int = 1
|
||||
"""The number of beams to be used in beam search."""
|
||||
|
||||
length_penalty: float = 1.0
|
||||
"""Length penalty to be used for beam search."""
|
||||
|
||||
include_stop_str_in_output: bool = False
|
||||
"""Whether to include the stop strings in output text."""
|
||||
|
||||
temperature: float = Field(default=0.0)
|
||||
"""The sampling temperature, between 0 and 1.
|
||||
|
||||
@@ -170,6 +183,29 @@ class TranscriptionRequest(OpenAIBaseModel):
|
||||
"min_p": 0.0,
|
||||
}
|
||||
|
||||
def to_beam_search_params(
    self,
    default_max_tokens: int,
    default_sampling_params: dict | None = None,
) -> BeamSearchParams:
    """Build the ``BeamSearchParams`` for this transcription request.

    Args:
        default_max_tokens: Token budget, passed through as ``max_tokens``.
        default_sampling_params: Optional defaults; only its ``"temperature"``
            entry is consulted, and only when the request's own temperature
            is ``None``.

    Returns:
        Beam search parameters whose beam width is the request's ``n``
        (1 when ``n`` is ``None``).
    """
    fallback_params = (
        {} if default_sampling_params is None else default_sampling_params
    )

    beam_width = 1 if self.n is None else self.n

    temperature = self.temperature
    if temperature is None:
        # NOTE: Temp 0 is a different fallback than completions
        temperature = fallback_params.get("temperature", 0)

    return BeamSearchParams(
        beam_width=beam_width,
        max_tokens=default_max_tokens,
        temperature=temperature,
        length_penalty=self.length_penalty,
        include_stop_str_in_output=self.include_stop_str_in_output,
    )
|
||||
|
||||
def to_sampling_params(
|
||||
self, default_max_tokens: int, default_sampling_params: dict | None = None
|
||||
) -> SamplingParams:
|
||||
@@ -376,6 +412,18 @@ class TranslationRequest(OpenAIBaseModel):
|
||||
|
||||
# TODO support additional sampling parameters
|
||||
# --8<-- [start:translation-sampling-params]
|
||||
use_beam_search: bool = False
|
||||
"""Whether or not beam search should be used."""
|
||||
|
||||
n: int = 1
|
||||
"""The number of beams to be used in beam search."""
|
||||
|
||||
length_penalty: float = 1.0
|
||||
"""Length penalty to be used for beam search."""
|
||||
|
||||
include_stop_str_in_output: bool = False
|
||||
"""Whether to include the stop strings in output text."""
|
||||
|
||||
seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
|
||||
"""The seed to use for sampling."""
|
||||
|
||||
@@ -424,6 +472,29 @@ class TranslationRequest(OpenAIBaseModel):
|
||||
"temperature": 0,
|
||||
}
|
||||
|
||||
def to_beam_search_params(
    self,
    default_max_tokens: int,
    default_sampling_params: dict | None = None,
) -> BeamSearchParams:
    """Build the ``BeamSearchParams`` for this translation request.

    Args:
        default_max_tokens: Token budget, passed through as ``max_tokens``.
        default_sampling_params: Optional defaults; only its ``"temperature"``
            entry is consulted, and only when the request's own temperature
            is ``None``.

    Returns:
        Beam search parameters whose beam width is the request's ``n``
        (1 when ``n`` is ``None``).
    """
    fallback_params = (
        {} if default_sampling_params is None else default_sampling_params
    )

    beam_width = 1 if self.n is None else self.n

    temperature = self.temperature
    if temperature is None:
        # NOTE: Temp 0 is a different fallback than completions
        temperature = fallback_params.get("temperature", 0)

    return BeamSearchParams(
        beam_width=beam_width,
        max_tokens=default_max_tokens,
        temperature=temperature,
        length_penalty=self.length_penalty,
        include_stop_str_in_output=self.include_stop_str_in_output,
    )
|
||||
|
||||
def to_sampling_params(
|
||||
self, default_max_tokens: int, default_sampling_params: dict | None = None
|
||||
) -> SamplingParams:
|
||||
|
||||
@@ -39,7 +39,7 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
|
||||
)
|
||||
from vllm.entrypoints.utils import get_max_tokens
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs import ProcessorInputs
|
||||
from vllm.inputs import EncoderDecoderInputs, ProcessorInputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import FlatLogprobs, Logprob
|
||||
from vllm.model_executor.models import (
|
||||
@@ -50,6 +50,7 @@ from vllm.multimodal.audio import split_audio
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.renderers.inputs import DictPrompt, EncoderDecoderDictPrompt
|
||||
from vllm.renderers.inputs.preprocess import parse_enc_dec_prompt, parse_model_prompt
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
@@ -264,8 +265,6 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
via ``get_language_detection_prompt`` and
|
||||
``parse_language_detection_output``.
|
||||
"""
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
prompt = self.model_cls.get_language_detection_prompt(
|
||||
audio_chunk,
|
||||
self.asr_config,
|
||||
@@ -403,6 +402,26 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
|
||||
return prompt
|
||||
|
||||
@staticmethod
def _get_decoder_prompt_len(engine_prompts: list[ProcessorInputs]) -> int:
    """Return the decoder prompt length of the first engine prompt.

    Only the first entry of ``engine_prompts`` is inspected. If its ``"type"``
    is ``"enc_dec"``, the number of decoder prompt token ids is returned;
    otherwise the result is 0.

    Currently we need to offset by the decoder prompt length when running
    beam search because the mm encoder is not currently cached and runs on
    decode calls; because of this, we need to make sure the redundant encoder
    calls won't exceed the context.

    FIXME (Alex) - this will be removed in the very near future once the
    encoder/decoder caching is implemented.
    """
    assert len(engine_prompts) > 0
    head_prompt = engine_prompts[0]

    # Anything other than an encoder-decoder prompt contributes no offset.
    if head_prompt.get("type") != "enc_dec":
        return 0

    enc_dec_prompt = cast(EncoderDecoderInputs, head_prompt)
    return len(enc_dec_prompt["decoder_prompt"]["prompt_token_ids"])
|
||||
|
||||
def _get_verbose_segments(
|
||||
self,
|
||||
tokens: tuple,
|
||||
@@ -481,6 +500,11 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
) -> T | V | AsyncGenerator[str, None] | ErrorResponse:
|
||||
"""Base method for speech-to-text operations like transcription and
|
||||
translation."""
|
||||
if request.stream and request.use_beam_search:
|
||||
return self.create_error_response(
|
||||
"Streaming is not currently supported with beam search"
|
||||
)
|
||||
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
@@ -526,6 +550,13 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
# Schedule the request and get the result generator.
|
||||
max_model_len = self.model_config.max_model_len
|
||||
list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None
|
||||
|
||||
input_len = (
|
||||
OpenAISpeechToText._get_decoder_prompt_len(engine_prompts)
|
||||
if request.use_beam_search
|
||||
else 0
|
||||
)
|
||||
|
||||
# Unlike most decoder-only models, whisper generation length is not
|
||||
# constrained by the size of the input audio, which is mapped to a
|
||||
# fixed-size log-mel-spectrogram. Still, allow for fewer tokens to be
|
||||
@@ -533,14 +564,20 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
max_tokens = get_max_tokens(
|
||||
max_model_len,
|
||||
request.max_completion_tokens,
|
||||
0,
|
||||
input_len,
|
||||
self.default_sampling_params,
|
||||
)
|
||||
|
||||
sampling_params = request.to_sampling_params(
|
||||
max_tokens,
|
||||
self.default_sampling_params,
|
||||
)
|
||||
if request.use_beam_search:
|
||||
sampling_params = request.to_beam_search_params(
|
||||
max_tokens, self.default_sampling_params
|
||||
)
|
||||
else:
|
||||
sampling_params = request.to_sampling_params(
|
||||
max_tokens,
|
||||
self.default_sampling_params,
|
||||
)
|
||||
|
||||
if request.response_format == "verbose_json":
|
||||
sampling_params.logprobs = 1
|
||||
|
||||
@@ -561,13 +598,22 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
else await self._get_trace_headers(raw_request.headers)
|
||||
)
|
||||
|
||||
generator = self.engine_client.generate(
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
)
|
||||
if isinstance(sampling_params, BeamSearchParams):
|
||||
generator = self.beam_search(
|
||||
prompt=engine_prompt,
|
||||
params=sampling_params,
|
||||
request_id=request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
)
|
||||
else:
|
||||
generator = self.engine_client.generate(
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
)
|
||||
|
||||
list_result_generator.append(generator)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user