diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py index 6f878b275..42a8132ff 100644 --- a/vllm/benchmarks/throughput.py +++ b/vllm/benchmarks/throughput.py @@ -17,6 +17,7 @@ from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase from vllm.benchmarks.datasets import ( AIMODataset, + ASRDataset, BurstGPTDataset, ConversationDataset, InstructCoderDataset, @@ -414,6 +415,12 @@ def get_requests(args, tokenizer): dataset_cls = AIMODataset common_kwargs["dataset_subset"] = None common_kwargs["dataset_split"] = "train" + elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS: + dataset_cls = ASRDataset + common_kwargs["dataset_subset"] = args.hf_subset + common_kwargs["dataset_split"] = args.hf_split + sample_kwargs["asr_min_audio_len_sec"] = args.asr_min_audio_len_sec + sample_kwargs["asr_max_audio_len_sec"] = args.asr_max_audio_len_sec elif args.dataset_name == "prefix_repetition": dataset_cls = PrefixRepetitionRandomDataset sample_kwargs["prefix_len"] = args.prefix_repetition_prefix_len @@ -557,6 +564,7 @@ def validate_args(args): elif args.dataset_path in ( InstructCoderDataset.SUPPORTED_DATASET_PATHS | AIMODataset.SUPPORTED_DATASET_PATHS + | ASRDataset.SUPPORTED_DATASET_PATHS ): assert args.backend == "vllm", ( f"{args.dataset_path} needs to use vllm as the backend." @@ -841,6 +849,20 @@ def add_cli_args(parser: argparse.ArgumentParser): add_random_dataset_base_args(parser) add_random_multimodal_dataset_args(parser) + # ASR dataset + parser.add_argument( + "--asr-min-audio-len-sec", + type=float, + default=0.0, + help="Minimum audio duration in seconds for ASR dataset filtering.", + ) + parser.add_argument( + "--asr-max-audio-len-sec", + type=float, + default=float("inf"), + help="Maximum audio duration in seconds for ASR dataset filtering.", + ) + parser = AsyncEngineArgs.add_cli_args(parser)