[ROCm][CI] Fix engine teardown and text normalization to stabilize voxtral test (#37138)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@@ -1,8 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import contextlib
 from dataclasses import asdict
 
 import pytest
+import pytest_asyncio
 from mistral_common.audio import Audio
 from mistral_common.protocol.instruct.chunk import RawAudio
 from mistral_common.protocol.transcription.request import (
@@ -17,18 +19,21 @@ from vllm.assets.audio import AudioAsset
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.v1.engine.async_llm import AsyncLLM
 
+from ....utils import ROCM_ENGINE_KWARGS
+
 MODEL_NAME = "mistralai/Voxtral-Mini-4B-Realtime-2602"
-ENGINE_CONFIG = dict(
-    model=MODEL_NAME,
-    max_model_len=8192,
-    max_num_seqs=4,
-    limit_mm_per_prompt={"audio": 1},
-    config_format="mistral",
-    load_format="mistral",
-    tokenizer_mode="mistral",
-    enforce_eager=True,
-    gpu_memory_utilization=0.9,
-)
+ENGINE_CONFIG = {
+    "model": MODEL_NAME,
+    "max_model_len": 8192,
+    "max_num_seqs": 4,
+    "limit_mm_per_prompt": {"audio": 1},
+    "config_format": "mistral",
+    "load_format": "mistral",
+    "tokenizer_mode": "mistral",
+    "enforce_eager": True,
+    "gpu_memory_utilization": 0.9,
+    **ROCM_ENGINE_KWARGS,
+}
 
 
 EXPECTED_TEXT = [
@@ -49,6 +54,14 @@ EXPECTED_TEXT = [
 ]
 
 
+def _normalize(texts: list[str]) -> list[str]:
+    # The model occasionally transcribes "OBS" as "a base hit" and
+    # "oh, my" as "oh my", but both are acoustically valid. Normalise so
+    # the assertion is stable across runs and hardware.
+    texts[1] = texts[1].replace("a base hit", "OBS").replace("oh my", "oh, my")
+    return texts
+
+
 @pytest.fixture
 def audio_assets() -> list[AudioAsset]:
     return [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
@@ -60,15 +73,27 @@ def tokenizer() -> MistralTokenizer:
 
 
 @pytest.fixture
-def engine() -> LLM:
+def engine():
     engine_args = EngineArgs(**ENGINE_CONFIG)
-    return LLM(**asdict(engine_args))
+    llm = LLM(**asdict(engine_args))
+    try:
+        yield llm
+    finally:
+        with contextlib.suppress(Exception):
+            llm.llm_engine.engine_core.shutdown()
+        import torch
+
+        torch.accelerator.empty_cache()
 
 
-@pytest.fixture
-def async_engine() -> AsyncLLM:
+@pytest_asyncio.fixture
+async def async_engine():
     engine_args = AsyncEngineArgs(**ENGINE_CONFIG)
-    return AsyncLLM.from_engine_args(engine_args)
+    llm = AsyncLLM.from_engine_args(engine_args)
+    try:
+        yield llm
+    finally:
+        llm.shutdown()
 
 
 def test_voxtral_realtime_forward(audio_assets, tokenizer, engine):
@@ -108,8 +133,13 @@ def test_voxtral_realtime_forward(audio_assets, tokenizer, engine):
         sampling_params=sampling_params,
     )
 
-    texts = [out.outputs[0].text for out in outputs]
-    assert texts == EXPECTED_TEXT
+    texts = _normalize([out.outputs[0].text for out in outputs])
+    for i, (got, expected) in enumerate(zip(texts, EXPECTED_TEXT)):
+        assert got == expected, (
+            f"Output mismatch at index {i}:\n"
+            f" got: {got!r}\n"
+            f" expected: {expected!r}"
+        )
 
 
 @pytest.mark.asyncio
@@ -149,9 +179,17 @@ async def test_voxtral_realtime_generator(audio_assets, tokenizer, async_engine)
 
         output_tokens_list.append(output_tokens)
 
-    texts = [
-        tokenizer.decode(output_tokens, special_token_policy=SpecialTokenPolicy.IGNORE)
-        for output_tokens in output_tokens_list
-    ]
-    texts[1] = texts[1].replace("a base hit", "OBS").replace("oh my", "oh, my")
-    assert texts == EXPECTED_TEXT
+    texts = _normalize(
+        [
+            tokenizer.decode(
+                output_tokens, special_token_policy=SpecialTokenPolicy.IGNORE
+            )
+            for output_tokens in output_tokens_list
+        ]
+    )
+    for i, (got, expected) in enumerate(zip(texts, EXPECTED_TEXT)):
+        assert got == expected, (
+            f"Output mismatch at index {i}:\n"
+            f" got: {got!r}\n"
+            f" expected: {expected!r}"
+        )
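
For context, the teardown half of this fix is pytest's generator (yield) fixture pattern: everything after the yield runs once the test finishes, even if the test body raised, and wrapping the shutdown in contextlib.suppress keeps a flaky teardown from masking the test's own result. Below is a minimal self-contained sketch of that pattern, with a hypothetical FakeEngine standing in for vLLM's LLM; it is an illustration of the technique, not code from the commit.

import contextlib

import pytest


class FakeEngine:
    # Hypothetical stand-in for an engine that holds a resource
    # (e.g. GPU memory) that must be released between tests.
    def __init__(self) -> None:
        self.alive = True

    def shutdown(self) -> None:
        self.alive = False


@pytest.fixture
def engine():
    eng = FakeEngine()
    try:
        # Yielding instead of returning makes this a generator fixture:
        # pytest resumes the generator after the test, so the finally
        # block runs even when the test body raises or an assert fails.
        yield eng
    finally:
        # Suppress teardown errors so a flaky shutdown cannot turn a
        # passing test into a failure, or hide the real failure.
        with contextlib.suppress(Exception):
            eng.shutdown()


def test_engine_alive(engine):
    assert engine.alive

pytest_asyncio.fixture provides the same yield-then-teardown semantics for async def fixtures, which is why the commit switches async_engine from @pytest.fixture to @pytest_asyncio.fixture alongside the yield.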
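
The normalization half of the fix can also be viewed as a table of acoustically equivalent variants mapped to one canonical spelling before comparison. The sketch below generalizes the commit's _normalize in that direction (the names and table are illustrative, not part of the commit), so that a newly observed variant only needs one new table entry:

# Acoustically equivalent transcription variants -> canonical spelling.
CANONICAL = {
    "a base hit": "OBS",
    "oh my": "oh, my",
}


def normalize(text: str) -> str:
    # Rewrite every known variant so transcripts compare stably
    # across runs and hardware.
    for variant, canonical in CANONICAL.items():
        text = text.replace(variant, canonical)
    return text


assert normalize("oh my, that was a base hit") == "oh, my, that was OBS"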