[Frontend] Add Support for MM Encoder/Decoder Beam Search (Online Transcriptions) (#36160)
Signed-off-by: Alex Brooks <albrooks@redhat.com>
This commit is contained in:
@@ -317,3 +317,72 @@ async def test_language_auto_detect(
|
||||
assert any(word.lower() in text_lower for word in expected_text), (
|
||||
f"Expected {expected_lang} text but got: {transcription.text}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_whisper_beam_search_single_beam(mary_had_lamb, whisper_client):
    """Test beam search with encoder-decoder model (Whisper) on transcriptions with
    one beam aligns with greedy decoding.
    """
    # With a single beam, beam search should degenerate to greedy decoding,
    # so both requests below are expected to yield identical transcripts.
    beam_transcription = await whisper_client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=mary_had_lamb,
        language="en",
        response_format="text",
        temperature=0.0,
        extra_body={"use_beam_search": True, "n": 1},
    )

    greedy_transcription = await whisper_client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=mary_had_lamb,
        response_format="text",
        temperature=0.0,
    )

    # Responses are JSON-encoded strings; compare the decoded "text" fields.
    beam_text = json.loads(beam_transcription)["text"]
    greedy_text = json.loads(greedy_transcription)["text"]
    assert greedy_text == beam_text
||||
|
||||
@pytest.mark.asyncio
async def test_whisper_beam_search_multibeam(mary_had_lamb, whisper_client):
    """Test n>1 for beam search returns one transcription (best beam)."""
    transcription = await whisper_client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=mary_had_lamb,
        language="en",
        response_format="text",
        temperature=0.0,
        extra_body={"use_beam_search": True, "n": 2},
    )

    # Even with two beams requested, a single (best-beam) transcript
    # comes back; decode it and sanity-check the content.
    best_beam_text = json.loads(transcription)["text"]

    assert best_beam_text is not None
    assert len(best_beam_text) > 0
    assert "mary had a little lamb" in best_beam_text.lower()
||||
|
||||
@pytest.mark.asyncio
async def test_stream_with_beams_raises(winning_call, whisper_client):
    """Test that stream=True + beam search raises bad request for now."""
    # Streaming is currently incompatible with beam search; the server
    # should reject the combination with a 400 rather than start streaming.
    with pytest.raises(openai.BadRequestError):
        await whisper_client.audio.transcriptions.create(
            model=MODEL_NAME,
            file=winning_call,
            language="en",
            stream=True,
            extra_body={"use_beam_search": True, "n": 2},
        )
|
||||
Reference in New Issue
Block a user