[Misc] Benchmarks for audio models (#16505)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi
2025-04-19 11:24:14 +02:00
committed by GitHub
parent 2ef0dc53b8
commit 9d4ca19d50
4 changed files with 199 additions and 5 deletions

View File

@@ -64,6 +64,7 @@ class SampleRequest:
class BenchmarkDataset(ABC):
DEFAULT_SEED = 0
IS_MULTIMODAL = False
def __init__(
self,
@@ -621,6 +622,7 @@ class ConversationDataset(HuggingFaceDataset):
SUPPORTED_DATASET_PATHS = {
'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered'
}
IS_MULTIMODAL = True
def sample(self,
tokenizer: PreTrainedTokenizerBase,
@@ -685,6 +687,7 @@ class VisionArenaDataset(HuggingFaceDataset):
"lmarena-ai/vision-arena-bench-v0.1":
lambda x: x["turns"][0][0]["content"]
}
IS_MULTIMODAL = True
def sample(
self,
@@ -815,3 +818,80 @@ class AIMODataset(HuggingFaceDataset):
))
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
# -----------------------------------------------------------------------------
# ASR Dataset Implementation
# -----------------------------------------------------------------------------
class ASRDataset(HuggingFaceDataset):
"""
Dataset class for processing a ASR dataset for transcription.
Tested on the following set:
+----------------+----------------------------------------+--------------------------+-----------------------------+
| Dataset | Domain | Speaking Style | hf-subset |
+----------------+----------------------------------------+--------------------------+-----------------------------+
| TED-LIUM | TED talks | Oratory | release1, release2, release3|
| | | | release3-speaker-adaptation |
| VoxPopuli | European Parliament | Oratory | en, de, it, fr, ... |
| LibriSpeech | Audiobook | Narrated | "LIUM/tedlium" |
| GigaSpeech | Audiobook, podcast, YouTube | Narrated, spontaneous | xs, s, m, l, xl, dev, test |
| SPGISpeech | Financial meetings | Oratory, spontaneous | S, M, L, dev, test |
| AMI | Meetings | Spontaneous | ihm, sdm |
+----------------+----------------------------------------+--------------------------+-----------------------------+
""" # noqa: E501
SUPPORTED_DATASET_PATHS = {
"openslr/librispeech_asr", "facebook/voxpopuli", "LIUM/tedlium",
"edinburghcstr/ami", "speechcolab/gigaspeech", "kensho/spgispeech"
}
DEFAULT_OUTPUT_LEN = 128
IS_MULTIMODAL = True
# TODO Whisper-specific. Abstract interface when more models are supported.
TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|>"\
"<|notimestamps|>"
skip_long_audios: bool = True
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
**kwargs,
) -> list:
import librosa
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests = []
skipped = 0
for item in self.data:
if len(sampled_requests) >= num_requests:
break
audio = item["audio"]
y, sr = audio["array"], audio["sampling_rate"]
duration_s = librosa.get_duration(y=y, sr=sr)
# Whisper max supported duration
if self.skip_long_audios and duration_s > 30:
skipped += 1
continue
mm_content = {"audio": (y, sr)}
sampled_requests.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=mm_content,
))
if skipped:
logger.warning("%d samples discarded from dataset due to" \
" their length being greater than" \
" what Whisper supports.", skipped)
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests