[ASR] Fix audio benchmark and add RTFx metric (#32300)
Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
This commit is contained in:
@@ -30,6 +30,7 @@ th {
|
||||
| HuggingFace-Other | ✅ | ✅ | `lmms-lab/LLaVA-OneVision-Data`, `Aeala/ShareGPT_Vicuna_unfiltered` |
|
||||
| HuggingFace-MTBench | ✅ | ✅ | `philschmid/mt-bench` |
|
||||
| HuggingFace-Blazedit | ✅ | ✅ | `vdaita/edit_5k_char`, `vdaita/edit_10k_char` |
|
||||
| HuggingFace-ASR | ✅ | ✅ | `openslr/librispeech_asr`, `facebook/voxpopuli`, `LIUM/tedlium`, `edinburghcstr/ami`, `speechcolab/gigaspeech`, `kensho/spgispeech` |
|
||||
| Spec Bench | ✅ | ✅ | `wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl` |
|
||||
| Custom | ✅ | ✅ | Local file: `data.jsonl` |
|
||||
| Custom MM | ✅ | ✅ | Local file: `mm_data.jsonl` |
|
||||
@@ -299,6 +300,22 @@ vllm bench serve \
|
||||
--blazedit-max-distance 0.99
|
||||
```
|
||||
|
||||
`openslr/librispeech_asr`, `facebook/voxpopuli`, `LIUM/tedlium`, `edinburghcstr/ami`, `speechcolab/gigaspeech`, `kensho/spgispeech`
|
||||
|
||||
```bash
|
||||
vllm bench serve \
|
||||
--model openai/whisper-large-v3-turbo \
|
||||
--backend openai-audio \
|
||||
--dataset-name hf \
|
||||
--dataset-path facebook/voxpopuli --hf-subset en --hf-split test --no-stream --trust-remote-code \
|
||||
--num-prompts 99999999 \
|
||||
--no-oversample \
|
||||
--endpoint /v1/audio/transcriptions \
|
||||
--ready-check-timeout-sec 600 \
|
||||
--save-result \
|
||||
--max-concurrency 512
|
||||
```
|
||||
|
||||
#### Running With Sampling Parameters
|
||||
|
||||
When using OpenAI-compatible backends such as `vllm`, optional sampling
|
||||
|
||||
@@ -1443,6 +1443,20 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
|
||||
help="Maximum distance for blazedit dataset. Min: 0, Max: 1.0",
|
||||
)
|
||||
|
||||
asr_group = parser.add_argument_group("asr dataset options")
|
||||
asr_group.add_argument(
|
||||
"--asr-max-audio-len-sec",
|
||||
type=float,
|
||||
default=float("inf"),
|
||||
help="Maximum audio length in seconds for ASR dataset.",
|
||||
)
|
||||
asr_group.add_argument(
|
||||
"--asr-min-audio-len-sec",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Minimum audio length in seconds for ASR dataset.",
|
||||
)
|
||||
|
||||
random_group = parser.add_argument_group("random dataset options")
|
||||
add_random_dataset_base_args(random_group)
|
||||
|
||||
@@ -1744,27 +1758,27 @@ def get_samples(args, tokenizer: TokenizerLike) -> list[SampleRequest]:
|
||||
or args.hf_name in VisionArenaDataset.SUPPORTED_DATASET_PATHS
|
||||
):
|
||||
dataset_class = VisionArenaDataset
|
||||
args.hf_split = "train"
|
||||
args.hf_split = args.hf_split if args.hf_split else "train"
|
||||
args.hf_subset = None
|
||||
elif (
|
||||
args.dataset_path in MMVUDataset.SUPPORTED_DATASET_PATHS
|
||||
or args.hf_name in MMVUDataset.SUPPORTED_DATASET_PATHS
|
||||
):
|
||||
dataset_class = MMVUDataset
|
||||
args.hf_split = "validation"
|
||||
args.hf_split = args.hf_split if args.hf_split else "validation"
|
||||
args.hf_subset = None
|
||||
elif (
|
||||
args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS
|
||||
or args.hf_name in InstructCoderDataset.SUPPORTED_DATASET_PATHS
|
||||
):
|
||||
dataset_class = InstructCoderDataset
|
||||
args.hf_split = "train"
|
||||
args.hf_split = args.hf_split if args.hf_split else "train"
|
||||
elif (
|
||||
args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS
|
||||
or args.hf_name in MTBenchDataset.SUPPORTED_DATASET_PATHS
|
||||
):
|
||||
dataset_class = MTBenchDataset
|
||||
args.hf_split = "train"
|
||||
args.hf_split = args.hf_split if args.hf_split else "train"
|
||||
elif (
|
||||
args.dataset_path in MultiModalConversationDataset.SUPPORTED_DATASET_PATHS
|
||||
or args.hf_name in MultiModalConversationDataset.SUPPORTED_DATASET_PATHS
|
||||
@@ -1780,22 +1794,26 @@ def get_samples(args, tokenizer: TokenizerLike) -> list[SampleRequest]:
|
||||
or args.hf_name in AIMODataset.SUPPORTED_DATASET_PATHS
|
||||
):
|
||||
dataset_class = AIMODataset
|
||||
args.hf_split = "train"
|
||||
args.hf_split = args.hf_split if args.hf_split else "train"
|
||||
elif (
|
||||
args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS # noqa: E501
|
||||
or args.hf_name in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS
|
||||
):
|
||||
dataset_class = NextEditPredictionDataset
|
||||
args.hf_split = "train"
|
||||
args.hf_split = args.hf_split if args.hf_split else "train"
|
||||
elif (
|
||||
args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS
|
||||
or args.hf_name in ASRDataset.SUPPORTED_DATASET_PATHS
|
||||
):
|
||||
dataset_class = ASRDataset
|
||||
args.hf_split = "train"
|
||||
args.hf_split = args.hf_split if args.hf_split else "train"
|
||||
hf_kwargs = {
|
||||
"asr_min_audio_len_sec": args.asr_min_audio_len_sec,
|
||||
"asr_max_audio_len_sec": args.asr_max_audio_len_sec,
|
||||
}
|
||||
elif args.dataset_path in BlazeditDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_class = BlazeditDataset
|
||||
args.hf_split = "train"
|
||||
args.hf_split = args.hf_split if args.hf_split else "train"
|
||||
hf_kwargs = {
|
||||
"min_distance": args.blazedit_min_distance,
|
||||
"max_distance": args.blazedit_max_distance,
|
||||
@@ -1805,13 +1823,13 @@ def get_samples(args, tokenizer: TokenizerLike) -> list[SampleRequest]:
|
||||
or args.hf_name in MLPerfDataset.SUPPORTED_DATASET_PATHS
|
||||
):
|
||||
dataset_class = MLPerfDataset
|
||||
args.hf_split = "train"
|
||||
args.hf_split = args.hf_split if args.hf_split else "train"
|
||||
elif (
|
||||
args.dataset_path in MMStarDataset.SUPPORTED_DATASET_PATHS
|
||||
or args.hf_name in MMStarDataset.SUPPORTED_DATASET_PATHS
|
||||
):
|
||||
dataset_class = MMStarDataset
|
||||
args.hf_split = "val"
|
||||
args.hf_split = args.hf_split if args.hf_split else "val"
|
||||
args.hf_subset = None
|
||||
else:
|
||||
supported_datasets = set(
|
||||
@@ -1847,6 +1865,7 @@ def get_samples(args, tokenizer: TokenizerLike) -> list[SampleRequest]:
|
||||
no_stream=args.no_stream,
|
||||
hf_name=args.hf_name,
|
||||
disable_shuffle=args.disable_shuffle,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
).sample(
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
@@ -2405,6 +2424,7 @@ class HuggingFaceDataset(BenchmarkDataset):
|
||||
no_stream: bool = False,
|
||||
dataset_subset: str | None = None,
|
||||
hf_name: str | None = None,
|
||||
trust_remote_code: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(dataset_path=dataset_path, **kwargs)
|
||||
@@ -2413,6 +2433,7 @@ class HuggingFaceDataset(BenchmarkDataset):
|
||||
self.dataset_subset = dataset_subset
|
||||
self.load_stream = not no_stream
|
||||
self.hf_name = hf_name or dataset_path
|
||||
self.trust_remote_code = trust_remote_code
|
||||
self.load_data()
|
||||
|
||||
def load_data(self) -> None:
|
||||
@@ -2422,6 +2443,7 @@ class HuggingFaceDataset(BenchmarkDataset):
|
||||
name=self.dataset_subset,
|
||||
split=self.dataset_split,
|
||||
streaming=self.load_stream,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
)
|
||||
if not getattr(self, "disable_shuffle", False):
|
||||
self.data = self.data.shuffle(seed=self.random_seed)
|
||||
@@ -3071,13 +3093,9 @@ class ASRDataset(HuggingFaceDataset):
|
||||
"kensho/spgispeech",
|
||||
}
|
||||
|
||||
DEFAULT_OUTPUT_LEN = 128
|
||||
DEFAULT_OUTPUT_LEN = 1024
|
||||
IS_MULTIMODAL = True
|
||||
|
||||
# TODO Whisper-specific. Abstract interface when more models are supported.
|
||||
TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
|
||||
skip_long_audios: bool = True
|
||||
|
||||
def sample(
|
||||
self,
|
||||
tokenizer: TokenizerLike,
|
||||
@@ -3088,22 +3106,28 @@ class ASRDataset(HuggingFaceDataset):
|
||||
**kwargs,
|
||||
) -> list:
|
||||
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
|
||||
prompt = ASRDataset.TRANSCRIPTION_PREAMBLE
|
||||
if "openai" in tokenizer.name_or_path:
|
||||
prompt = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
|
||||
else:
|
||||
prompt = ""
|
||||
prompt_len = len(tokenizer(prompt).input_ids)
|
||||
sampled_requests = []
|
||||
ind = 0
|
||||
skipped = 0
|
||||
asr_min_audio_len_sec = kwargs.get("asr_min_audio_len_sec")
|
||||
asr_max_audio_len_sec = kwargs.get("asr_max_audio_len_sec")
|
||||
durations = []
|
||||
for item in self.data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
audio = item["audio"]
|
||||
y, sr = audio["array"], audio["sampling_rate"]
|
||||
duration_s = librosa.get_duration(y=y, sr=sr)
|
||||
# Whisper max supported duration
|
||||
if self.skip_long_audios and duration_s > 30:
|
||||
if duration_s < asr_min_audio_len_sec or duration_s > asr_max_audio_len_sec:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
durations.append(duration_s)
|
||||
mm_content = {"audio": (y, sr)}
|
||||
sampled_requests.append(
|
||||
SampleRequest(
|
||||
@@ -3122,6 +3146,20 @@ class ASRDataset(HuggingFaceDataset):
|
||||
" what Whisper supports.",
|
||||
skipped,
|
||||
)
|
||||
|
||||
logger.info("Number of audio samples: %d", len(durations))
|
||||
avg_duration = sum(durations) / len(durations) if durations else 0
|
||||
min_duration = min(durations) if durations else 0
|
||||
max_duration = max(durations) if durations else 0
|
||||
median_duration = np.median(durations) if durations else 0
|
||||
logger.info(
|
||||
"Audio duration statistics (s): avg=%.2f, min=%.2f, max=%.2f, median=%.2f",
|
||||
avg_duration,
|
||||
min_duration,
|
||||
max_duration,
|
||||
median_duration,
|
||||
)
|
||||
|
||||
self.maybe_oversample_requests(
|
||||
sampled_requests, num_requests, request_id_prefix, no_oversample
|
||||
)
|
||||
|
||||
@@ -93,6 +93,7 @@ class RequestFuncOutput:
|
||||
prompt_len: int = 0
|
||||
error: str = ""
|
||||
start_time: float = 0.0
|
||||
input_audio_duration: float = 0.0 # in seconds
|
||||
|
||||
|
||||
class RequestFunc(Protocol):
|
||||
@@ -422,6 +423,8 @@ async def async_request_openai_audio(
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
output.input_audio_duration = soundfile.info(f).duration
|
||||
f.seek(0)
|
||||
|
||||
generated_text = ""
|
||||
ttft = 0.0
|
||||
@@ -442,7 +445,9 @@ async def async_request_openai_audio(
|
||||
|
||||
messages = handler.add_chunk(chunk_bytes)
|
||||
for message in messages:
|
||||
chunk = message.decode("utf-8").removeprefix("data: ")
|
||||
if type(message) is bytes:
|
||||
message = message.decode("utf-8")
|
||||
chunk = message.removeprefix("data: ")
|
||||
if chunk != "[DONE]":
|
||||
timestamp = time.perf_counter()
|
||||
data = json.loads(chunk)
|
||||
|
||||
@@ -193,6 +193,7 @@ class BenchmarkMetrics:
|
||||
# Max output tokens per second and concurrent requests at that peak
|
||||
max_output_tokens_per_s: float
|
||||
max_concurrent_requests: int
|
||||
rtfx: float = 0.0 # Inverse Real-Time Factor for ASR benchmarks
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -412,6 +413,7 @@ def calculate_metrics(
|
||||
all_tpots: list[float] = []
|
||||
ttfts: list[float] = []
|
||||
e2els: list[float] = []
|
||||
input_audio_duration = 0.0
|
||||
for i in range(len(outputs)):
|
||||
if outputs[i].success:
|
||||
output_len = outputs[i].output_tokens
|
||||
@@ -439,6 +441,7 @@ def calculate_metrics(
|
||||
itls += outputs[i].itl
|
||||
ttfts.append(outputs[i].ttft)
|
||||
e2els.append(outputs[i].latency)
|
||||
input_audio_duration += outputs[i].input_audio_duration
|
||||
completed += 1
|
||||
else:
|
||||
actual_output_lens.append(0)
|
||||
@@ -583,6 +586,7 @@ def calculate_metrics(
|
||||
],
|
||||
max_output_tokens_per_s=max_output_tokens_per_s,
|
||||
max_concurrent_requests=max_concurrent_requests,
|
||||
rtfx=input_audio_duration / dur_s,
|
||||
)
|
||||
|
||||
return metrics, actual_output_lens
|
||||
@@ -937,6 +941,12 @@ async def benchmark(
|
||||
"Peak concurrent requests:", metrics.max_concurrent_requests
|
||||
)
|
||||
)
|
||||
if metrics.rtfx > 0.0:
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
"RTFx (Inverse Real-Time Factor):", metrics.rtfx
|
||||
)
|
||||
)
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
"Total token throughput (tok/s):", metrics.total_token_throughput
|
||||
@@ -963,6 +973,7 @@ async def benchmark(
|
||||
"errors": [output.error for output in outputs],
|
||||
"max_output_tokens_per_s": metrics.max_output_tokens_per_s,
|
||||
"max_concurrent_requests": metrics.max_concurrent_requests,
|
||||
"rtfx": metrics.rtfx,
|
||||
}
|
||||
else:
|
||||
result = {
|
||||
|
||||
Reference in New Issue
Block a user