[ASR] Fix audio benchmark and add RTFx metric (#32300)

Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
This commit is contained in:
Ekagra Ranjan
2026-02-09 05:02:37 -05:00
committed by GitHub
parent 3025b3cebb
commit 1d5922fade
4 changed files with 90 additions and 19 deletions

View File

@@ -30,6 +30,7 @@ th {
| HuggingFace-Other | ✅ | ✅ | `lmms-lab/LLaVA-OneVision-Data`, `Aeala/ShareGPT_Vicuna_unfiltered` |
| HuggingFace-MTBench | ✅ | ✅ | `philschmid/mt-bench` |
| HuggingFace-Blazedit | ✅ | ✅ | `vdaita/edit_5k_char`, `vdaita/edit_10k_char` |
| HuggingFace-ASR | ✅ | ✅ | `openslr/librispeech_asr`, `facebook/voxpopuli`, `LIUM/tedlium`, `edinburghcstr/ami`, `speechcolab/gigaspeech`, `kensho/spgispeech` |
| Spec Bench | ✅ | ✅ | `wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl` |
| Custom | ✅ | ✅ | Local file: `data.jsonl` |
| Custom MM | ✅ | ✅ | Local file: `mm_data.jsonl` |
@@ -299,6 +300,22 @@ vllm bench serve \
--blazedit-max-distance 0.99
```
The `HuggingFace-ASR` dataset supports the following dataset paths: `openslr/librispeech_asr`, `facebook/voxpopuli`, `LIUM/tedlium`, `edinburghcstr/ami`, `speechcolab/gigaspeech`, `kensho/spgispeech`. Example:
```bash
vllm bench serve \
--model openai/whisper-large-v3-turbo \
--backend openai-audio \
--dataset-name hf \
--dataset-path facebook/voxpopuli --hf-subset en --hf-split test --no-stream --trust-remote-code \
--num-prompts 99999999 \
--no-oversample \
--endpoint /v1/audio/transcriptions \
--ready-check-timeout-sec 600 \
--save-result \
--max-concurrency 512
```
#### Running With Sampling Parameters
When using OpenAI-compatible backends such as `vllm`, optional sampling

View File

@@ -1443,6 +1443,20 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
help="Maximum distance for blazedit dataset. Min: 0, Max: 1.0",
)
asr_group = parser.add_argument_group("asr dataset options")
asr_group.add_argument(
"--asr-max-audio-len-sec",
type=float,
default=float("inf"),
help="Maximum audio length in seconds for ASR dataset.",
)
asr_group.add_argument(
"--asr-min-audio-len-sec",
type=float,
default=0.0,
help="Minimum audio length in seconds for ASR dataset.",
)
random_group = parser.add_argument_group("random dataset options")
add_random_dataset_base_args(random_group)
@@ -1744,27 +1758,27 @@ def get_samples(args, tokenizer: TokenizerLike) -> list[SampleRequest]:
or args.hf_name in VisionArenaDataset.SUPPORTED_DATASET_PATHS
):
dataset_class = VisionArenaDataset
args.hf_split = "train"
args.hf_split = args.hf_split if args.hf_split else "train"
args.hf_subset = None
elif (
args.dataset_path in MMVUDataset.SUPPORTED_DATASET_PATHS
or args.hf_name in MMVUDataset.SUPPORTED_DATASET_PATHS
):
dataset_class = MMVUDataset
args.hf_split = "validation"
args.hf_split = args.hf_split if args.hf_split else "validation"
args.hf_subset = None
elif (
args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS
or args.hf_name in InstructCoderDataset.SUPPORTED_DATASET_PATHS
):
dataset_class = InstructCoderDataset
args.hf_split = "train"
args.hf_split = args.hf_split if args.hf_split else "train"
elif (
args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS
or args.hf_name in MTBenchDataset.SUPPORTED_DATASET_PATHS
):
dataset_class = MTBenchDataset
args.hf_split = "train"
args.hf_split = args.hf_split if args.hf_split else "train"
elif (
args.dataset_path in MultiModalConversationDataset.SUPPORTED_DATASET_PATHS
or args.hf_name in MultiModalConversationDataset.SUPPORTED_DATASET_PATHS
@@ -1780,22 +1794,26 @@ def get_samples(args, tokenizer: TokenizerLike) -> list[SampleRequest]:
or args.hf_name in AIMODataset.SUPPORTED_DATASET_PATHS
):
dataset_class = AIMODataset
args.hf_split = "train"
args.hf_split = args.hf_split if args.hf_split else "train"
elif (
args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS # noqa: E501
or args.hf_name in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS
):
dataset_class = NextEditPredictionDataset
args.hf_split = "train"
args.hf_split = args.hf_split if args.hf_split else "train"
elif (
args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS
or args.hf_name in ASRDataset.SUPPORTED_DATASET_PATHS
):
dataset_class = ASRDataset
args.hf_split = "train"
args.hf_split = args.hf_split if args.hf_split else "train"
hf_kwargs = {
"asr_min_audio_len_sec": args.asr_min_audio_len_sec,
"asr_max_audio_len_sec": args.asr_max_audio_len_sec,
}
elif args.dataset_path in BlazeditDataset.SUPPORTED_DATASET_PATHS:
dataset_class = BlazeditDataset
args.hf_split = "train"
args.hf_split = args.hf_split if args.hf_split else "train"
hf_kwargs = {
"min_distance": args.blazedit_min_distance,
"max_distance": args.blazedit_max_distance,
@@ -1805,13 +1823,13 @@ def get_samples(args, tokenizer: TokenizerLike) -> list[SampleRequest]:
or args.hf_name in MLPerfDataset.SUPPORTED_DATASET_PATHS
):
dataset_class = MLPerfDataset
args.hf_split = "train"
args.hf_split = args.hf_split if args.hf_split else "train"
elif (
args.dataset_path in MMStarDataset.SUPPORTED_DATASET_PATHS
or args.hf_name in MMStarDataset.SUPPORTED_DATASET_PATHS
):
dataset_class = MMStarDataset
args.hf_split = "val"
args.hf_split = args.hf_split if args.hf_split else "val"
args.hf_subset = None
else:
supported_datasets = set(
@@ -1847,6 +1865,7 @@ def get_samples(args, tokenizer: TokenizerLike) -> list[SampleRequest]:
no_stream=args.no_stream,
hf_name=args.hf_name,
disable_shuffle=args.disable_shuffle,
trust_remote_code=args.trust_remote_code,
).sample(
num_requests=args.num_prompts,
tokenizer=tokenizer,
@@ -2405,6 +2424,7 @@ class HuggingFaceDataset(BenchmarkDataset):
no_stream: bool = False,
dataset_subset: str | None = None,
hf_name: str | None = None,
trust_remote_code: bool = False,
**kwargs,
) -> None:
super().__init__(dataset_path=dataset_path, **kwargs)
@@ -2413,6 +2433,7 @@ class HuggingFaceDataset(BenchmarkDataset):
self.dataset_subset = dataset_subset
self.load_stream = not no_stream
self.hf_name = hf_name or dataset_path
self.trust_remote_code = trust_remote_code
self.load_data()
def load_data(self) -> None:
@@ -2422,6 +2443,7 @@ class HuggingFaceDataset(BenchmarkDataset):
name=self.dataset_subset,
split=self.dataset_split,
streaming=self.load_stream,
trust_remote_code=self.trust_remote_code,
)
if not getattr(self, "disable_shuffle", False):
self.data = self.data.shuffle(seed=self.random_seed)
@@ -3071,13 +3093,9 @@ class ASRDataset(HuggingFaceDataset):
"kensho/spgispeech",
}
DEFAULT_OUTPUT_LEN = 128
DEFAULT_OUTPUT_LEN = 1024
IS_MULTIMODAL = True
# TODO Whisper-specific. Abstract interface when more models are supported.
TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
skip_long_audios: bool = True
def sample(
self,
tokenizer: TokenizerLike,
@@ -3088,22 +3106,28 @@ class ASRDataset(HuggingFaceDataset):
**kwargs,
) -> list:
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
if "openai" in tokenizer.name_or_path:
prompt = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
else:
prompt = ""
prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests = []
ind = 0
skipped = 0
asr_min_audio_len_sec = kwargs.get("asr_min_audio_len_sec")
asr_max_audio_len_sec = kwargs.get("asr_max_audio_len_sec")
durations = []
for item in self.data:
if len(sampled_requests) >= num_requests:
break
audio = item["audio"]
y, sr = audio["array"], audio["sampling_rate"]
duration_s = librosa.get_duration(y=y, sr=sr)
# Whisper max supported duration
if self.skip_long_audios and duration_s > 30:
if duration_s < asr_min_audio_len_sec or duration_s > asr_max_audio_len_sec:
skipped += 1
continue
durations.append(duration_s)
mm_content = {"audio": (y, sr)}
sampled_requests.append(
SampleRequest(
@@ -3122,6 +3146,20 @@ class ASRDataset(HuggingFaceDataset):
" what Whisper supports.",
skipped,
)
logger.info("Number of audio samples: %d", len(durations))
avg_duration = sum(durations) / len(durations) if durations else 0
min_duration = min(durations) if durations else 0
max_duration = max(durations) if durations else 0
median_duration = np.median(durations) if durations else 0
logger.info(
"Audio duration statistics (s): avg=%.2f, min=%.2f, max=%.2f, median=%.2f",
avg_duration,
min_duration,
max_duration,
median_duration,
)
self.maybe_oversample_requests(
sampled_requests, num_requests, request_id_prefix, no_oversample
)

View File

@@ -93,6 +93,7 @@ class RequestFuncOutput:
prompt_len: int = 0
error: str = ""
start_time: float = 0.0
input_audio_duration: float = 0.0 # in seconds
class RequestFunc(Protocol):
@@ -422,6 +423,8 @@ async def async_request_openai_audio(
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
output.input_audio_duration = soundfile.info(f).duration
f.seek(0)
generated_text = ""
ttft = 0.0
@@ -442,7 +445,9 @@ async def async_request_openai_audio(
messages = handler.add_chunk(chunk_bytes)
for message in messages:
chunk = message.decode("utf-8").removeprefix("data: ")
if type(message) is bytes:
message = message.decode("utf-8")
chunk = message.removeprefix("data: ")
if chunk != "[DONE]":
timestamp = time.perf_counter()
data = json.loads(chunk)

View File

@@ -193,6 +193,7 @@ class BenchmarkMetrics:
# Max output tokens per second and concurrent requests at that peak
max_output_tokens_per_s: float
max_concurrent_requests: int
rtfx: float = 0.0 # Inverse Real-Time Factor for ASR benchmarks
@dataclass
@@ -412,6 +413,7 @@ def calculate_metrics(
all_tpots: list[float] = []
ttfts: list[float] = []
e2els: list[float] = []
input_audio_duration = 0.0
for i in range(len(outputs)):
if outputs[i].success:
output_len = outputs[i].output_tokens
@@ -439,6 +441,7 @@ def calculate_metrics(
itls += outputs[i].itl
ttfts.append(outputs[i].ttft)
e2els.append(outputs[i].latency)
input_audio_duration += outputs[i].input_audio_duration
completed += 1
else:
actual_output_lens.append(0)
@@ -583,6 +586,7 @@ def calculate_metrics(
],
max_output_tokens_per_s=max_output_tokens_per_s,
max_concurrent_requests=max_concurrent_requests,
rtfx=input_audio_duration / dur_s,
)
return metrics, actual_output_lens
@@ -937,6 +941,12 @@ async def benchmark(
"Peak concurrent requests:", metrics.max_concurrent_requests
)
)
if metrics.rtfx > 0.0:
print(
"{:<40} {:<10.2f}".format(
"RTFx (Inverse Real-Time Factor):", metrics.rtfx
)
)
print(
"{:<40} {:<10.2f}".format(
"Total token throughput (tok/s):", metrics.total_token_throughput
@@ -963,6 +973,7 @@ async def benchmark(
"errors": [output.error for output in outputs],
"max_output_tokens_per_s": metrics.max_output_tokens_per_s,
"max_concurrent_requests": metrics.max_concurrent_requests,
"rtfx": metrics.rtfx,
}
else:
result = {