[ROCm][CI] Fix engine teardown and text normalization to stabilize voxtral test (#37138)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-03-15 23:49:31 -05:00
committed by GitHub
parent 68e1b711f1
commit d4c57863f7
2 changed files with 68 additions and 24 deletions

View File

@@ -1,8 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from dataclasses import asdict
import pytest
import pytest_asyncio
from mistral_common.audio import Audio
from mistral_common.protocol.instruct.chunk import RawAudio
from mistral_common.protocol.transcription.request import (
@@ -17,18 +19,21 @@ from vllm.assets.audio import AudioAsset
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM
from ....utils import ROCM_ENGINE_KWARGS
MODEL_NAME = "mistralai/Voxtral-Mini-4B-Realtime-2602"
ENGINE_CONFIG = dict(
model=MODEL_NAME,
max_model_len=8192,
max_num_seqs=4,
limit_mm_per_prompt={"audio": 1},
config_format="mistral",
load_format="mistral",
tokenizer_mode="mistral",
enforce_eager=True,
gpu_memory_utilization=0.9,
)
ENGINE_CONFIG = {
"model": MODEL_NAME,
"max_model_len": 8192,
"max_num_seqs": 4,
"limit_mm_per_prompt": {"audio": 1},
"config_format": "mistral",
"load_format": "mistral",
"tokenizer_mode": "mistral",
"enforce_eager": True,
"gpu_memory_utilization": 0.9,
**ROCM_ENGINE_KWARGS,
}
EXPECTED_TEXT = [
@@ -49,6 +54,14 @@ EXPECTED_TEXT = [
]
def _normalize(texts: list[str]) -> list[str]:
# The model occasionally transcribes "OBS" as "a base hit" and
# "oh, my" as "oh my", but both are acoustically valid. Normalise so
# the assertion is stable across runs and hardware.
texts[1] = texts[1].replace("a base hit", "OBS").replace("oh my", "oh, my")
return texts
@pytest.fixture
def audio_assets() -> list[AudioAsset]:
    """Audio clips exercised by the realtime transcription tests."""
    clip_names = ("mary_had_lamb", "winning_call")
    return [AudioAsset(name) for name in clip_names]
@@ -60,15 +73,27 @@ def tokenizer() -> MistralTokenizer:
@pytest.fixture
def engine() -> LLM:
def engine():
engine_args = EngineArgs(**ENGINE_CONFIG)
return LLM(**asdict(engine_args))
llm = LLM(**asdict(engine_args))
try:
yield llm
finally:
with contextlib.suppress(Exception):
llm.llm_engine.engine_core.shutdown()
import torch
torch.accelerator.empty_cache()
@pytest.fixture
def async_engine() -> AsyncLLM:
@pytest_asyncio.fixture
async def async_engine():
engine_args = AsyncEngineArgs(**ENGINE_CONFIG)
return AsyncLLM.from_engine_args(engine_args)
llm = AsyncLLM.from_engine_args(engine_args)
try:
yield llm
finally:
llm.shutdown()
def test_voxtral_realtime_forward(audio_assets, tokenizer, engine):
@@ -108,8 +133,13 @@ def test_voxtral_realtime_forward(audio_assets, tokenizer, engine):
sampling_params=sampling_params,
)
texts = [out.outputs[0].text for out in outputs]
assert texts == EXPECTED_TEXT
texts = _normalize([out.outputs[0].text for out in outputs])
for i, (got, expected) in enumerate(zip(texts, EXPECTED_TEXT)):
assert got == expected, (
f"Output mismatch at index {i}:\n"
f" got: {got!r}\n"
f" expected: {expected!r}"
)
@pytest.mark.asyncio
@@ -149,9 +179,17 @@ async def test_voxtral_realtime_generator(audio_assets, tokenizer, async_engine)
output_tokens_list.append(output_tokens)
texts = [
tokenizer.decode(output_tokens, special_token_policy=SpecialTokenPolicy.IGNORE)
for output_tokens in output_tokens_list
]
texts[1] = texts[1].replace("a base hit", "OBS").replace("oh my", "oh, my")
assert texts == EXPECTED_TEXT
texts = _normalize(
[
tokenizer.decode(
output_tokens, special_token_policy=SpecialTokenPolicy.IGNORE
)
for output_tokens in output_tokens_list
]
)
for i, (got, expected) in enumerate(zip(texts, EXPECTED_TEXT)):
assert got == expected, (
f"Output mismatch at index {i}:\n"
f" got: {got!r}\n"
f" expected: {expected!r}"
)