[CI] Heavy refactoring of Voxtral multimodal audio model tests (#34294)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@@ -1215,3 +1215,91 @@ def tarsier_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
    hf_processor.patch_size = vision_encoder_info.get_patch_size()

    return hf_model


def voxtral_patch_hf_runner(hf_model: "HfRunner") -> "HfRunner":
    """Patch HfRunner for Voxtral's conversation-based processor.

    Two issues in HfRunner require patching:

    1. VoxtralProcessor requires ``apply_chat_template()`` with conversation
       dicts (accepting ``url``, ``path``, or ``base64`` audio) rather than
       the standard ``processor(text=..., audio=..., sampling_rate=...)``
       interface.
    2. HfRunner.get_inputs cannot handle multiple audio clips per prompt
       because it mis-unpacks ``[(arr1, sr1), (arr2, sr2)]`` via a
       ``len == 2`` check.

    We override ``get_inputs`` to build conversation dicts and call
    ``apply_chat_template`` directly, bypassing both issues. We also wrap
    ``model.generate`` to strip prompt tokens before decoding, since
    HfRunner.generate calls batch_decode on the full sequence (prompt +
    generated tokens).
    """

    import base64
    import io

    import soundfile as sf

    processor = hf_model.processor

    def _audio_to_base64(audio_array, sample_rate: int) -> str:
        """Encode a numpy audio array as a base64 WAV string."""
        buf = io.BytesIO()
        sf.write(buf, audio_array, int(sample_rate), format="WAV")
        return base64.b64encode(buf.getvalue()).decode("ascii")
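
    # Explanatory note (added, not in the original hunk): a plain
    # ``len == 2`` check cannot tell a single ``(arr, sr)`` tuple apart
    # from a list of two ``(arr, sr)`` pairs, since both have length 2.
    # The loop below therefore normalizes to a list first and checks
    # ``isinstance(item, (list, tuple))`` per element.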
    def patched_get_inputs(prompts, images=None, videos=None, audios=None, **kwargs):
        all_inputs = []
        for i, prompt in enumerate(prompts):
            content: list[dict] = []

            if audios is not None and audios[i] is not None:
                items = audios[i]
                if not isinstance(items, list):
                    items = [items]
                for item in items:
                    if isinstance(item, (list, tuple)) and len(item) == 2:
                        arr, sr = item
                    else:
                        arr, sr = item, 16_000
                    content.append(
                        {
                            "type": "audio",
                            "base64": _audio_to_base64(arr, sr),
                        }
                    )

            content.append({"type": "text", "text": prompt})

            inputs = processor.apply_chat_template(
                [{"role": "user", "content": content}]
            )
            if hasattr(inputs, "to"):
                inputs = inputs.to(dtype=hf_model.dtype)
            all_inputs.append(inputs)

        return all_inputs

    _orig_generate = hf_model.model.generate

    def patched_generate(*args, **kwargs):
        """Strip prompt tokens so only generated tokens are decoded."""
        input_ids = kwargs.get("input_ids")
        if input_ids is None and args:
            input_ids = args[0]
        prompt_len = input_ids.shape[1] if input_ids is not None else 0

        output = _orig_generate(*args, **kwargs)
        if prompt_len:
            if isinstance(output, torch.Tensor):
                output = output[:, prompt_len:]
            else:
                # GenerateDecoderOnlyOutput: trim sequences but preserve
                # scores/logits so generate_greedy_logprobs_limit can
                # extract per-token logprobs.
                output.sequences = output.sequences[:, prompt_len:]
        return output

    hf_model.get_inputs = patched_get_inputs  # type: ignore[method-assign, assignment]
    hf_model.model.generate = patched_generate  # type: ignore[method-assign]
    return hf_model
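
For context, a minimal usage sketch of the patched runner. This is illustrative only: the model name, the load_audio_fixture helper, and the exact HfRunner call shape are assumptions based on vLLM's multimodal test harness, not part of this commit.

    # Hypothetical test snippet; model name, fixture helper, and keyword
    # arguments are illustrative assumptions.
    with HfRunner("mistralai/Voxtral-Mini-3B-2507") as hf_model:
        hf_model = voxtral_patch_hf_runner(hf_model)
        audio, sr = load_audio_fixture()  # e.g. a 16 kHz mono numpy array
        outputs = hf_model.generate_greedy(
            ["Transcribe the audio."],
            max_tokens=64,
            audios=[[(audio, sr)]],  # one list per prompt; multiple clips allowed
        )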