[CI] Heavy refactoring of Voxtral multimodal audio model tests (#34294)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Andreas Karatzas
2026-02-13 22:04:29 -06:00
committed by GitHub
parent 60ca7981bc
commit de42abb366
11 changed files with 350 additions and 70 deletions


@@ -1215,3 +1215,91 @@ def tarsier_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
    hf_processor.patch_size = vision_encoder_info.get_patch_size()
    return hf_model


def voxtral_patch_hf_runner(hf_model: "HfRunner") -> "HfRunner":
    """Patch HfRunner for Voxtral's conversation-based processor.

    Two issues in HfRunner require patching:

    1. VoxtralProcessor requires ``apply_chat_template()`` with conversation
       dicts (accepting ``url``, ``path``, or ``base64`` audio) rather than
       the standard ``processor(text=, audio=, sampling_rate=)`` interface.
    2. HfRunner.get_inputs cannot handle multiple audios per prompt because
       it mis-unpacks ``[(arr1, sr1), (arr2, sr2)]`` via a ``len == 2``
       check.

    We override ``get_inputs`` to build conversation dicts and call
    ``apply_chat_template`` directly, bypassing both issues. We also wrap
    ``model.generate`` to strip prompt tokens before decoding, since
    ``HfRunner.generate`` calls ``batch_decode`` on the full sequence
    (prompt + generated).
    """
    import base64
    import io

    import soundfile as sf

    processor = hf_model.processor

    def _audio_to_base64(audio_array, sample_rate: int) -> str:
        """Encode a numpy audio array as a base64 WAV string."""
        buf = io.BytesIO()
        sf.write(buf, audio_array, int(sample_rate), format="WAV")
        return base64.b64encode(buf.getvalue()).decode("ascii")
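
    # Example: a one-second 16 kHz sine tone round-trips through
    # _audio_to_base64 (illustrative sketch; assumes numpy is available):
    #
    #   import numpy as np
    #   sr = 16_000
    #   tone = np.sin(2 * np.pi * 440 * np.arange(sr) / sr).astype("float32")
    #   b64 = _audio_to_base64(tone, sr)
    #   decoded, decoded_sr = sf.read(io.BytesIO(base64.b64decode(b64)))
    #   assert decoded_sr == sr and len(decoded) == sr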

    def patched_get_inputs(
        prompts, images=None, videos=None, audios=None, **kwargs
    ):
        all_inputs = []
        for i, prompt in enumerate(prompts):
            content: list[dict] = []
            if audios is not None and audios[i] is not None:
                items = audios[i]
                if not isinstance(items, list):
                    items = [items]
                for item in items:
                    if isinstance(item, (list, tuple)) and len(item) == 2:
                        arr, sr = item
                    else:
                        arr, sr = item, 16_000
                    content.append(
                        {
                            "type": "audio",
                            "base64": _audio_to_base64(arr, sr),
                        }
                    )
            content.append({"type": "text", "text": prompt})
            inputs = processor.apply_chat_template(
                [{"role": "user", "content": content}]
            )
            if hasattr(inputs, "to"):
                inputs = inputs.to(dtype=hf_model.dtype)
            all_inputs.append(inputs)
        return all_inputs

    _orig_generate = hf_model.model.generate

    def patched_generate(*args, **kwargs):
        """Strip prompt tokens so only generated tokens are decoded."""
        input_ids = kwargs.get("input_ids")
        if input_ids is None and args:
            input_ids = args[0]
        prompt_len = input_ids.shape[1] if input_ids is not None else 0
        output = _orig_generate(*args, **kwargs)
        if prompt_len:
            if isinstance(output, torch.Tensor):
                output = output[:, prompt_len:]
            else:
                # GenerateDecoderOnlyOutput - trim sequences but preserve
                # scores/logits so generate_greedy_logprobs_limit can
                # extract per-token logprobs.
                output.sequences = output.sequences[:, prompt_len:]
        return output
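
    # Example: with a 4-token prompt and 6 generated tokens, the slice keeps
    # only the generated suffix (illustrative shapes, not from the tests):
    #
    #   full = torch.arange(10).unsqueeze(0)  # shape (1, 10)
    #   full[:, 4:]                           # shape (1, 6): generated tokens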

    hf_model.get_inputs = patched_get_inputs  # type: ignore[method-assign, assignment]
    hf_model.model.generate = patched_generate  # type: ignore[method-assign]
    return hf_model
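
For reference, a minimal sketch of the per-prompt conversation that patched_get_inputs assembles before calling processor.apply_chat_template (the prompt text and base64 placeholders here are illustrative, not taken from the tests):

    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "audio", "base64": "<base64 WAV, clip 1>"},
                {"type": "audio", "base64": "<base64 WAV, clip 2>"},
                {"type": "text", "text": "Describe what you hear."},
            ],
        }
    ]
    # inputs = processor.apply_chat_template(conversation)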