[CI] Heavy refactoring of Voxtral multimodal audio model tests (#34294)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-02-13 22:04:29 -06:00
committed by GitHub
parent 60ca7981bc
commit de42abb366
11 changed files with 350 additions and 70 deletions

View File

@@ -4,16 +4,18 @@
 import json
 import pytest
-import pytest_asyncio
 from mistral_common.audio import Audio
 from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
 from mistral_common.protocol.instruct.messages import UserMessage
+from transformers import VoxtralForConditionalGeneration
 from vllm.tokenizers.mistral import MistralTokenizer
 
 from ....conftest import AudioTestAssets
 from ....utils import RemoteOpenAIServer
+from ...utils import check_logprobs_close
 from .test_ultravox import MULTI_AUDIO_PROMPT, run_multi_audio_test
+from .vlm_utils import model_utils
 
 MODEL_NAME = "mistralai/Voxtral-Mini-3B-2507"
 
 MISTRAL_FORMAT_ARGS = [
@@ -26,40 +28,21 @@ MISTRAL_FORMAT_ARGS = [
 ]
 
-@pytest.fixture()
-def server(request, audio_assets: AudioTestAssets):
-    args = [
-        "--enforce-eager",
-        "--limit-mm-per-prompt",
-        json.dumps({"audio": len(audio_assets)}),
-    ] + MISTRAL_FORMAT_ARGS
-
-    with RemoteOpenAIServer(
-        MODEL_NAME, args, env_dict={"VLLM_AUDIO_FETCH_TIMEOUT": "30"}
-    ) as remote_server:
-        yield remote_server
-
-@pytest_asyncio.fixture
-async def client(server):
-    async with server.get_async_client() as async_client:
-        yield async_client
-
-def _get_prompt(audio_assets, question):
+def _get_prompt(audio_assets: AudioTestAssets, question: str) -> list[int]:
+    """Build a token-ID prompt via mistral_common for vLLM offline inference."""
     tokenizer = MistralTokenizer.from_pretrained(MODEL_NAME)
     audios = [
-        Audio.from_file(str(audio_assets[i].get_local_path()), strict=False)
-        for i in range(len(audio_assets))
+        Audio.from_file(str(asset.get_local_path()), strict=False)
+        for asset in audio_assets
     ]
     audio_chunks = [
         AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios
     ]
-    text_chunk = TextChunk(text=question)
-    messages = [UserMessage(content=[*audio_chunks, text_chunk]).to_openai()]
+    messages = [
+        UserMessage(content=[*audio_chunks, TextChunk(text=question)]).to_openai()
+    ]
     return tokenizer.apply_chat_template(messages=messages)
@@ -77,7 +60,7 @@ def test_models_with_multiple_audios(
     vllm_prompt = _get_prompt(audio_assets, MULTI_AUDIO_PROMPT)
     run_multi_audio_test(
         vllm_runner,
-        [(vllm_prompt, [audio.audio_and_sample_rate for audio in audio_assets])],
+        [(vllm_prompt, [a.audio_and_sample_rate for a in audio_assets])],  # type: ignore[list-item]
         MODEL_NAME,
         dtype=dtype,
         max_tokens=max_tokens,
@@ -86,30 +69,142 @@
     )
 
-@pytest.mark.asyncio
-async def test_online_serving(client, audio_assets: AudioTestAssets):
-    """Exercises online serving with/without chunked prefill enabled."""
+def test_online_serving(vllm_runner, audio_assets: AudioTestAssets):
+    """Two-layer accuracy and serving validation using Mistral format.
+
+    1. Offline vLLM greedy output (runs first to avoid CUDA fork issues
+       with multiprocessing - see vlm_utils/core.py).
+    2. Online OpenAI-compatible API output must match offline - validates
+       that the serving path (chat template, audio encoding, tokenization)
+       does not corrupt anything.
+
+    Steps run sequentially so each releases the GPU before the next starts.
+    """
+    question = f"What's happening in these {len(audio_assets)} audio clips?"
+    max_tokens = 10
+    audio_data = [asset.audio_and_sample_rate for asset in audio_assets]
+    vllm_prompt = _get_prompt(audio_assets, question)
+
+    with vllm_runner(
+        MODEL_NAME,
+        dtype="half",
+        enforce_eager=True,
+        tokenizer_mode="mistral",
+        config_format="mistral",
+        load_format="mistral",
+        limit_mm_per_prompt={"audio": len(audio_assets)},
+    ) as vllm_model:
+        offline_outputs = vllm_model.generate_greedy(
+            [vllm_prompt],
+            max_tokens,
+            audios=[audio_data],
+        )
+    offline_text = offline_outputs[0][1]
+    assert offline_text, "Offline vLLM inference produced empty output"
+
-    def asset_to_chunk(asset):
+    def _asset_to_openai_chunk(asset):
         audio = Audio.from_file(str(asset.get_local_path()), strict=False)
         audio.format = "wav"
-        audio_dict = AudioChunk.from_audio(audio).to_openai()
-        return audio_dict
+        return AudioChunk.from_audio(audio).to_openai()
 
-    audio_chunks = [asset_to_chunk(asset) for asset in audio_assets]
-    text = f"What's happening in these {len(audio_assets)} audio clips?"
     messages = [
         {
             "role": "user",
-            "content": [*audio_chunks, {"type": "text", "text": text}],
+            "content": [
+                *[_asset_to_openai_chunk(a) for a in audio_assets],
+                {"type": "text", "text": question},
+            ],
         }
     ]
-    chat_completion = await client.chat.completions.create(
-        model=MODEL_NAME, messages=messages, max_tokens=10
-    )
-    assert len(chat_completion.choices) == 1
-    choice = chat_completion.choices[0]
-    assert choice.message.content == "In the first audio clip, you hear a brief"
-    assert choice.finish_reason == "length"
+    server_args = [
+        "--enforce-eager",
+        "--limit-mm-per-prompt",
+        json.dumps({"audio": len(audio_assets)}),
+        *MISTRAL_FORMAT_ARGS,
+    ]
+    with RemoteOpenAIServer(
+        MODEL_NAME,
+        server_args,
+        env_dict={"VLLM_AUDIO_FETCH_TIMEOUT": "30"},
+    ) as remote_server:
+        client = remote_server.get_client()
+        completion = client.chat.completions.create(
+            model=MODEL_NAME,
+            messages=messages,
+            max_tokens=max_tokens,
+            temperature=0,
+        )
+    assert len(completion.choices) == 1
+    choice = completion.choices[0]
+    assert choice.finish_reason == "length"
+    assert choice.message.content == offline_text, (
+        f"Online serving output does not match offline inference.\n"
+        f" Online: {choice.message.content!r}\n"
+        f" Offline: {offline_text!r}"
+    )
+
+
+@pytest.mark.xfail(strict=False)
+def test_hf_reference(hf_runner, vllm_runner, audio_assets: AudioTestAssets):
+    """Compare vLLM Mistral-format output against HF Transformers reference.
+
+    Instead of requiring an exact text match (which is brittle across
+    attention backends), we compare per-token logprobs using the standard
+    check_logprobs_close helper: when tokens diverge at a position, each
+    runner's chosen token must appear in the other's top-k logprobs.
+
+    Marked xfail(strict=False) so remaining edge-case mismatches
+    don't block CI.
+    """
+    question = f"What's happening in these {len(audio_assets)} audio clips?"
+    max_tokens = 10
+    num_logprobs = 5
+    audio_data = [asset.audio_and_sample_rate for asset in audio_assets]
+    vllm_prompt = _get_prompt(audio_assets, question)
+
+    with vllm_runner(
+        MODEL_NAME,
+        dtype="half",
+        enforce_eager=True,
+        tokenizer_mode="mistral",
+        config_format="mistral",
+        load_format="mistral",
+        limit_mm_per_prompt={"audio": len(audio_assets)},
+    ) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy_logprobs(
+            [vllm_prompt],
+            max_tokens,
+            num_logprobs,
+            audios=[audio_data],
+        )
+    assert vllm_outputs[0][1], "vLLM inference produced empty output"
+
+    with hf_runner(
+        MODEL_NAME,
+        dtype="half",
+        auto_cls=VoxtralForConditionalGeneration,
+    ) as hf_model:
+        hf_model = model_utils.voxtral_patch_hf_runner(hf_model)
+        hf_outputs = hf_model.generate_greedy_logprobs_limit(
+            [question],
+            max_tokens,
+            num_logprobs,
+            audios=[audio_data],
+        )
+    assert hf_outputs[0][1], "HF Transformers produced empty output"
+
+    print(
+        f"HF Reference Comparison\n"
+        f" vLLM: {vllm_outputs[0][1]!r}\n"
+        f" HF: {hf_outputs[0][1]!r}"
+    )
+    check_logprobs_close(
+        outputs_0_lst=vllm_outputs,
+        outputs_1_lst=hf_outputs,
+        name_0="vllm",
+        name_1="hf",
+    )
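
The divergence rule that check_logprobs_close applies is easy to state precisely. Below is a minimal Python sketch of that rule as the test_hf_reference docstring describes it; tokens_agree and its toy inputs are illustrative names only, not the helper's real API (the real implementation also handles prompt offsets, length mismatches, and missing logprobs):

# Sketch of the acceptance rule: equal tokens always pass; unequal
# tokens pass only if each runner's chosen token appears in the other
# runner's top-k logprob set for that position.
def tokens_agree(tok_a: int, tok_b: int,
                 topk_a: set[int], topk_b: set[int]) -> bool:
    if tok_a == tok_b:
        return True
    return tok_a in topk_b and tok_b in topk_a

# Divergent but mutually top-ranked tokens -> accepted.
assert tokens_agree(5, 9, topk_a={5, 9, 2}, topk_b={9, 5, 7})
# A token absent from the other runner's top-k -> rejected.
assert not tokens_agree(5, 9, topk_a={5, 2, 3}, topk_b={9, 7, 8})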

View File

@@ -10,6 +10,7 @@ from mistral_common.protocol.transcription.request import (
     TranscriptionRequest,
 )
 from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+from mistral_common.tokens.tokenizers.tekken import SpecialTokenPolicy
 
 from vllm import LLM, EngineArgs, SamplingParams
 from vllm.assets.audio import AudioAsset
@@ -26,7 +27,7 @@ ENGINE_CONFIG = dict(
     load_format="mistral",
     tokenizer_mode="mistral",
     enforce_eager=True,
-    gpu_memory_utilization=0.4,
+    gpu_memory_utilization=0.9,
 )
@@ -148,6 +149,9 @@ async def test_voxtral_realtime_generator(audio_assets, tokenizer, async_engine)
         output_tokens_list.append(output_tokens)
 
-    texts = [tokenizer.decode(output_tokens) for output_tokens in output_tokens_list]
+    texts = [
+        tokenizer.decode(output_tokens, special_token_policy=SpecialTokenPolicy.IGNORE)
+        for output_tokens in output_tokens_list
+    ]
     texts[1] = texts[1].replace("a base hit", "OBS").replace("oh my", "oh, my")
     assert texts == EXPECTED_TEXT
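
The SpecialTokenPolicy change matters because decoding with a policy other than IGNORE can either render special tokens into the output text or raise on them, breaking the exact-match assertion against EXPECTED_TEXT. Below is a toy, self-contained model of the three policies under their assumed semantics (IGNORE drops special tokens, KEEP renders them, RAISE errors); the vocabulary and token IDs are made up, and this is not mistral_common's implementation:

from enum import Enum

class Policy(Enum):
    IGNORE = "ignore"
    KEEP = "keep"
    RAISE = "raise"

SPECIAL = {1: "</s>"}             # hypothetical special-token table
VOCAB = {7: "hello", 8: " world"} # hypothetical regular vocabulary

def toy_decode(ids: list[int], policy: Policy) -> str:
    out = []
    for t in ids:
        if t in SPECIAL:
            if policy is Policy.RAISE:
                raise ValueError(f"special token {t} in decode")
            if policy is Policy.KEEP:
                out.append(SPECIAL[t])
            # Policy.IGNORE: skip the special token silently.
        else:
            out.append(VOCAB[t])
    return "".join(out)

assert toy_decode([7, 8, 1], Policy.IGNORE) == "hello world"
assert toy_decode([7, 8, 1], Policy.KEEP) == "hello world</s>"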

View File

@@ -1215,3 +1215,91 @@ def tarsier_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
         hf_processor.patch_size = vision_encoder_info.get_patch_size()
     return hf_model
+
+
+def voxtral_patch_hf_runner(hf_model: "HfRunner") -> "HfRunner":
+    """Patch HfRunner for Voxtral's conversation-based processor.
+
+    Two issues in HfRunner require patching:
+
+    1. VoxtralProcessor requires ``apply_chat_template()`` with conversation
+       dicts (accepting ``url``, ``path``, or ``base64`` audio) rather than
+       the standard ``processor(text=, audio=, sampling_rate=)`` interface.
+    2. HfRunner.get_inputs cannot handle multi-audio per prompt because it
+       mis-unpacks ``[(arr1, sr1), (arr2, sr2)]`` via a ``len == 2`` check.
+
+    We override ``get_inputs`` to build conversation dicts and call
+    ``apply_chat_template`` directly, bypassing both issues. We also wrap
+    ``model.generate`` to strip prompt tokens before decoding, since
+    HfRunner.generate calls batch_decode on the full sequence (prompt +
+    generated).
+    """
+    import base64
+    import io
+
+    import soundfile as sf
+
+    processor = hf_model.processor
+
+    def _audio_to_base64(audio_array, sample_rate: int) -> str:
+        """Encode a numpy audio array as a base64 WAV string."""
+        buf = io.BytesIO()
+        sf.write(buf, audio_array, int(sample_rate), format="WAV")
+        return base64.b64encode(buf.getvalue()).decode("ascii")
+
+    def patched_get_inputs(prompts, images=None, videos=None, audios=None, **kwargs):
+        all_inputs = []
+        for i, prompt in enumerate(prompts):
+            content: list[dict] = []
+            if audios is not None and audios[i] is not None:
+                items = audios[i]
+                if not isinstance(items, list):
+                    items = [items]
+                for item in items:
+                    if isinstance(item, (list, tuple)) and len(item) == 2:
+                        arr, sr = item
+                    else:
+                        arr, sr = item, 16_000
+                    content.append(
+                        {
+                            "type": "audio",
+                            "base64": _audio_to_base64(arr, sr),
+                        }
+                    )
+            content.append({"type": "text", "text": prompt})
+            inputs = processor.apply_chat_template(
+                [{"role": "user", "content": content}]
+            )
+            if hasattr(inputs, "to"):
+                inputs = inputs.to(dtype=hf_model.dtype)
+            all_inputs.append(inputs)
+        return all_inputs
+
+    _orig_generate = hf_model.model.generate
+
+    def patched_generate(*args, **kwargs):
+        """Strip prompt tokens so only generated tokens are decoded."""
+        input_ids = kwargs.get("input_ids")
+        if input_ids is None and args:
+            input_ids = args[0]
+        prompt_len = input_ids.shape[1] if input_ids is not None else 0
+        output = _orig_generate(*args, **kwargs)
+        if prompt_len:
+            if isinstance(output, torch.Tensor):
+                output = output[:, prompt_len:]
+            else:
+                # GenerateDecoderOnlyOutput - trim sequences but preserve
+                # scores/logits so generate_greedy_logprobs_limit can
+                # extract per-token logprobs.
+                output.sequences = output.sequences[:, prompt_len:]
+        return output
+
+    hf_model.get_inputs = patched_get_inputs  # type: ignore[method-assign, assignment]
+    hf_model.model.generate = patched_generate  # type: ignore[method-assign]
+    return hf_model
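
The base64 WAV path used by _audio_to_base64 can be sanity-checked in isolation. The following self-contained sketch (not part of this commit) round-trips a synthetic tone through soundfile and base64 exactly the way the patched runner encodes audio; the loose tolerance accounts for soundfile writing WAV with its default PCM_16 subtype:

import base64
import io

import numpy as np
import soundfile as sf

# 0.1 s of a 440 Hz tone at 16 kHz, mono float32.
sr = 16_000
t = np.linspace(0, 0.1, int(sr * 0.1), endpoint=False)
wave = (0.1 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)

# Encode the same way _audio_to_base64 does.
buf = io.BytesIO()
sf.write(buf, wave, sr, format="WAV")
b64 = base64.b64encode(buf.getvalue()).decode("ascii")

# Decode and verify the payload survives the round trip; PCM_16
# quantizes samples slightly, hence the loose tolerance.
decoded, decoded_sr = sf.read(io.BytesIO(base64.b64decode(b64)))
assert decoded_sr == sr
assert np.allclose(decoded, wave, atol=1e-3)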