[Frontend] Add Support for MM Encoder/Decoder Beam Search (Offline) (#36153)
Signed-off-by: Alex Brooks <albrooks@redhat.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -90,9 +90,9 @@ def run_test(
|
||||
|
||||
|
||||
@pytest.fixture
def resampled_assets() -> list[tuple[Any, int]]:
    """Return the test audio assets resampled to Whisper's sample rate.

    Returns:
        A list of ``(audio, sample_rate)`` tuples, one per asset, where
        ``sample_rate`` is always ``WHISPER_SAMPLE_RATE`` (16 kHz).
    """
    audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
    sampled_assets = []
    for asset in audio_assets:
        audio, orig_sr = asset.audio_and_sample_rate
        # Resample to Whisper's expected sample rate (16kHz); the model
        # was trained on 16 kHz audio, so feeding the original rate would
        # produce garbage transcriptions rather than an error.
        audio = librosa.resample(
            audio, orig_sr=orig_sr, target_sr=WHISPER_SAMPLE_RATE
        )
        sampled_assets.append((audio, WHISPER_SAMPLE_RATE))
    return sampled_assets
|
||||
|
||||
|
||||
@pytest.fixture
def input_audios(
    resampled_assets,
) -> list[tuple[list[str], list[str], list[tuple[Any, int]]]]:
    """Build per-asset prompt/audio input tuples for runner comparisons.

    Args:
        resampled_assets: ``(audio, sample_rate)`` tuples already resampled
            to ``WHISPER_SAMPLE_RATE`` by the ``resampled_assets`` fixture.

    Returns:
        One ``([vLLM prompt], [HF prompt], [audio input])`` tuple per asset.
    """
    # Each entry pairs a single vLLM prompt and HF prompt with one
    # (audio, sample_rate) input.
    return [
        ([VLLM_PROMPT], [HF_PROMPT], [audio_info])
        for audio_info in resampled_assets
    ]
|
||||
|
||||
|
||||
@@ -111,6 +124,98 @@ def check_model_available(model: str) -> None:
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("beam_width", [1, 2])
def test_beam_search_encoder_decoder(
    monkeypatch,
    hf_runner,
    vllm_runner,
    dtype: str,
    max_tokens: int,
    beam_width: int,
    resampled_assets,
) -> None:
    """Test beam search with encoder-decoder models (Whisper).

    Runs beam search through both the HF reference runner and vLLM, using
    both implicit and explicit encoder/decoder prompt forms, and checks
    that vLLM produces the expected number of non-empty beams.
    """
    if current_platform.is_rocm():
        monkeypatch.setenv("VLLM_ROCM_USE_SKINNY_GEMM", "0")

    model = "openai/whisper-large-v3-turbo"
    check_model_available(model)

    # One transcription-start prompt per audio asset.
    hf_prompts = [
        "<|startoftranscript|>",
        "<|startoftranscript|>",
    ]

    with hf_runner(model, dtype=dtype, auto_cls=AutoModelForSpeechSeq2Seq) as hf_model:
        hf_outputs = hf_model.generate_beam_search(
            hf_prompts,
            beam_width=beam_width,
            max_tokens=max_tokens,
            audios=resampled_assets,
        )

    # Exercise both implicit and explicit encoder/decoder prompt formats.
    vllm_prompts = [
        # Implicit encoder/decoder prompt
        {
            "prompt": "<|startoftranscript|>",
            "multi_modal_data": {"audio": resampled_assets[0]},
        },
        # Explicit encoder/decoder prompt
        {
            "encoder_prompt": {
                "prompt": "",
                "multi_modal_data": {"audio": resampled_assets[1]},
            },
            "decoder_prompt": "<|startoftranscript|>",
        },
    ]

    with vllm_runner(
        model,
        # Use the parametrized dtype so the vLLM and HF runners stay in
        # sync if more dtypes are ever added to the parametrize list.
        dtype=dtype,
        max_model_len=448,
        tensor_parallel_size=1,
        max_num_seqs=4,
        limit_mm_per_prompt={"audio": 2},
        enforce_eager=True,
    ) as vllm_model:
        vllm_outputs = vllm_model.generate_beam_search(
            vllm_prompts,
            beam_width=beam_width,
            max_tokens=max_tokens,
        )

    for i in range(len(vllm_prompts)):
        hf_output_ids, hf_output_texts = hf_outputs[i]
        vllm_output_ids, vllm_output_texts = vllm_outputs[i]

        # Print side-by-side outputs for debugging failed comparisons.
        for j, (hf_text, vllm_text) in enumerate(
            zip(hf_output_texts, vllm_output_texts)
        ):
            print(f">>>{j}-th hf output [NOTE: special tokens are filtered]:")
            print(hf_text)
            print(f">>>{j}-th vllm output:")
            print(vllm_text)

        # Check that we got the same number of beams
        assert len(hf_output_ids) == len(vllm_output_ids)

        # For encoder-decoder models, we primarily want to verify that:
        # 1. Beam search completes without errors
        # 2. We get the expected number of beams
        # 3. Outputs are reasonable (non-empty, diverse beams)
        for j in range(len(vllm_output_ids)):
            # Check that outputs are not empty
            assert len(vllm_output_ids[j]) > 0, f"Prompt {i}, beam {j}: empty output"
            # Check that decoded text is not empty
            assert len(vllm_output_texts[j].strip()) > 0, (
                f"Prompt {i}, beam {j}: empty text output"
            )
|
||||
|
||||
|
||||
def test_parse_language_detection_output():
|
||||
"""Unit test for WhisperForConditionalGeneration.parse_language_detection_output.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user