[Model] Extend Ultravox to accept audio longer than 30s (#13631)
Signed-off-by: Farzad Abdolhosseini <farzad@fixie.ai>
This commit is contained in:
committed by
GitHub
parent
4a42b9f5d6
commit
80e78d02ac
@@ -1,6 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import copy
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -21,6 +23,7 @@ def _test_processing_correctness(
|
||||
hit_rate: float,
|
||||
num_batches: int,
|
||||
simplify_rate: float,
|
||||
ignore_mm_keys: Optional[list[str]] = None,
|
||||
):
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
@@ -123,8 +126,10 @@ def _test_processing_correctness(
|
||||
hf_processor_mm_kwargs={},
|
||||
)
|
||||
|
||||
assert baseline_result == cached_result, (
|
||||
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
|
||||
assert _drop_mm_kwargs_keys(
|
||||
baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
|
||||
cached_result, ignore_mm_keys), (
|
||||
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
|
||||
|
||||
baseline_tokenized_result = baseline_processor.apply(
|
||||
tokenizer.encode(prompt, **tokenizer_encode_kwargs),
|
||||
@@ -132,8 +137,10 @@ def _test_processing_correctness(
|
||||
hf_processor_mm_kwargs={},
|
||||
)
|
||||
|
||||
assert baseline_result == baseline_tokenized_result, (
|
||||
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
|
||||
assert _drop_mm_kwargs_keys(
|
||||
baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
|
||||
baseline_tokenized_result, ignore_mm_keys), (
|
||||
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
|
||||
|
||||
cached_tokenized_result = cached_processor.apply(
|
||||
tokenizer.encode(prompt, **tokenizer_encode_kwargs),
|
||||
@@ -141,8 +148,10 @@ def _test_processing_correctness(
|
||||
hf_processor_mm_kwargs={},
|
||||
)
|
||||
|
||||
assert cached_result == cached_tokenized_result, (
|
||||
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
|
||||
assert _drop_mm_kwargs_keys(
|
||||
cached_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
|
||||
cached_tokenized_result, ignore_mm_keys), (
|
||||
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@@ -173,7 +182,7 @@ def _test_processing_correctness(
|
||||
"Qwen/Qwen2-VL-2B-Instruct",
|
||||
"Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
"Qwen/Qwen2-Audio-7B-Instruct",
|
||||
"fixie-ai/ultravox-v0_4",
|
||||
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
|
||||
"openai/whisper-large-v3",
|
||||
"google/paligemma-3b-mix-224",
|
||||
"google/paligemma2-3b-ft-docci-448",
|
||||
@@ -188,11 +197,19 @@ def test_processing_correctness(
|
||||
num_batches: int,
|
||||
simplify_rate: float,
|
||||
):
|
||||
ignore_mm_keys = None
|
||||
if 'ultravox' in model_id:
|
||||
# In Ultravox, the audio_features can be different depending on padding
|
||||
# The slight difference should not be a problem though, since
|
||||
# attention_mask lets us ignore the difference.
|
||||
ignore_mm_keys = ['audio_features']
|
||||
|
||||
_test_processing_correctness(
|
||||
model_id,
|
||||
hit_rate=hit_rate,
|
||||
num_batches=num_batches,
|
||||
simplify_rate=simplify_rate,
|
||||
ignore_mm_keys=ignore_mm_keys,
|
||||
)
|
||||
|
||||
|
||||
@@ -221,3 +238,29 @@ def test_processing_correctness_phi3v(
|
||||
num_batches=num_batches,
|
||||
simplify_rate=simplify_rate,
|
||||
)
|
||||
|
||||
|
||||
def _drop_mm_kwargs_keys(result: dict,
|
||||
ignore_mm_keys: Optional[list[str]] = None) -> dict:
|
||||
"""Drop specified keys from result['mm_kwargs'].
|
||||
|
||||
This is mainly to avoid doing exact match of audio_features in ultravox.
|
||||
|
||||
Args:
|
||||
result: Result to drop keys from
|
||||
ignore_mm_keys: List of keys to ignore, e.g. ['audio_features']
|
||||
"""
|
||||
if not ignore_mm_keys:
|
||||
return result
|
||||
|
||||
if 'mm_kwargs' in result:
|
||||
result = copy.deepcopy(result)
|
||||
mm_kwargs = result['mm_kwargs']
|
||||
for key in ignore_mm_keys:
|
||||
mm_kwargs.pop(key, None)
|
||||
for items in mm_kwargs._items_by_modality.values():
|
||||
for item in items:
|
||||
for key in ignore_mm_keys:
|
||||
item.pop(key, None)
|
||||
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user