[Frontend] Gemma3n audio transcriptions/translations endpoint (#23735)
Signed-off-by: NickLucche <nlucches@redhat.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -12,32 +12,24 @@ import pytest
|
||||
import pytest_asyncio
|
||||
import soundfile as sf
|
||||
|
||||
from vllm.assets.audio import AudioAsset
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "openai/whisper-small"
|
||||
SERVER_ARGS = ["--enforce-eager"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def foscolo():
|
||||
# Test translation it->en
|
||||
path = AudioAsset('azacinto_foscolo').get_local_path()
|
||||
with open(str(path), "rb") as f:
|
||||
yield f
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as remote_server:
|
||||
yield remote_server
|
||||
@pytest.fixture(scope="module",
|
||||
params=["openai/whisper-small", "google/gemma-3n-E2B-it"])
|
||||
def server(request):
|
||||
# Parametrize over model name
|
||||
with RemoteOpenAIServer(request.param, SERVER_ARGS) as remote_server:
|
||||
yield remote_server, request.param
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server):
|
||||
async def client_and_model(server):
|
||||
server, model_name = server
|
||||
async with server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
yield async_client, model_name
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -56,27 +48,29 @@ async def test_non_asr_model(foscolo):
|
||||
|
||||
# NOTE: (NickLucche) the large-v3-turbo model was not trained on translation!
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_audio(foscolo, client):
|
||||
async def test_basic_audio(foscolo, client_and_model):
|
||||
client, model_name = client_and_model
|
||||
translation = await client.audio.translations.create(
|
||||
model=MODEL_NAME,
|
||||
model=model_name,
|
||||
file=foscolo,
|
||||
response_format="text",
|
||||
# TODO remove once language detection is implemented
|
||||
extra_body=dict(language="it"),
|
||||
# TODO remove `language="it"` once language detection is implemented
|
||||
extra_body=dict(language="it", to_language="en"),
|
||||
temperature=0.0)
|
||||
out = json.loads(translation)['text'].strip().lower()
|
||||
assert "greek sea" in out
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audio_prompt(foscolo, client):
|
||||
async def test_audio_prompt(foscolo, client_and_model):
|
||||
client, model_name = client_and_model
|
||||
# Condition whisper on starting text
|
||||
prompt = "Nor have I ever"
|
||||
transcription = await client.audio.translations.create(
|
||||
model=MODEL_NAME,
|
||||
model=model_name,
|
||||
file=foscolo,
|
||||
prompt=prompt,
|
||||
extra_body=dict(language="it"),
|
||||
extra_body=dict(language="it", to_language="en"),
|
||||
response_format="text",
|
||||
temperature=0.0)
|
||||
out = json.loads(transcription)['text']
|
||||
@@ -85,22 +79,27 @@ async def test_audio_prompt(foscolo, client):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_response(foscolo, client, server):
|
||||
async def test_streaming_response(foscolo, client_and_model, server):
|
||||
client, model_name = client_and_model
|
||||
translation = ""
|
||||
res_no_stream = await client.audio.translations.create(
|
||||
model=MODEL_NAME,
|
||||
model=model_name,
|
||||
file=foscolo,
|
||||
response_format="json",
|
||||
extra_body=dict(language="it"),
|
||||
extra_body=dict(language="it", to_language="en", seed=42),
|
||||
temperature=0.0)
|
||||
|
||||
# Stream via HTTPX since OpenAI translation client doesn't expose streaming
|
||||
server, model_name = server
|
||||
url = server.url_for("v1/audio/translations")
|
||||
headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"}
|
||||
data = {
|
||||
"model": MODEL_NAME,
|
||||
"model": model_name,
|
||||
"language": "it",
|
||||
"to_language": "en",
|
||||
"stream": True,
|
||||
"temperature": 0.0,
|
||||
"seed": 42,
|
||||
}
|
||||
foscolo.seek(0)
|
||||
async with httpx.AsyncClient() as http_client:
|
||||
@@ -121,16 +120,24 @@ async def test_streaming_response(foscolo, client, server):
|
||||
text = chunk["choices"][0].get("delta", {}).get("content")
|
||||
translation += text or ""
|
||||
|
||||
assert translation == res_no_stream.text
|
||||
res_stream = translation.split()
|
||||
# NOTE There's a small non-deterministic issue here, likely in the attn
|
||||
# computation, which will cause a few tokens to be different, while still
|
||||
# being very close semantically.
|
||||
assert sum([
|
||||
x == y for x, y in zip(res_stream, res_no_stream.text.split())
|
||||
]) >= len(res_stream) * 0.9
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_options(foscolo, client, server):
|
||||
async def test_stream_options(foscolo, server):
|
||||
server, model_name = server
|
||||
url = server.url_for("v1/audio/translations")
|
||||
headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"}
|
||||
data = {
|
||||
"model": MODEL_NAME,
|
||||
"model": model_name,
|
||||
"language": "it",
|
||||
"to_language": "en",
|
||||
"stream": True,
|
||||
"stream_include_usage": True,
|
||||
"stream_continuous_usage_stats": True,
|
||||
@@ -164,7 +171,10 @@ async def test_stream_options(foscolo, client, server):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_long_audio_request(foscolo, client):
|
||||
async def test_long_audio_request(foscolo, client_and_model):
|
||||
client, model_name = client_and_model
|
||||
if model_name == "google/gemma-3n-E2B-it":
|
||||
pytest.skip("Gemma3n does not support long audio requests")
|
||||
foscolo.seek(0)
|
||||
audio, sr = librosa.load(foscolo)
|
||||
repeated_audio = np.tile(audio, 2)
|
||||
@@ -173,9 +183,9 @@ async def test_long_audio_request(foscolo, client):
|
||||
sf.write(buffer, repeated_audio, sr, format='WAV')
|
||||
buffer.seek(0)
|
||||
translation = await client.audio.translations.create(
|
||||
model=MODEL_NAME,
|
||||
model=model_name,
|
||||
file=buffer,
|
||||
extra_body=dict(language="it"),
|
||||
extra_body=dict(language="it", to_language="en"),
|
||||
response_format="text",
|
||||
temperature=0.0)
|
||||
out = json.loads(translation)['text'].strip().lower()
|
||||
|
||||
Reference in New Issue
Block a user