[Bugfix] Add missing ASRDataset import and CLI args in benchmarks/throughput.py (#38114)
Signed-off-by: nemanjaudovic <nudovic@amd.com>
This commit is contained in:
@@ -17,6 +17,7 @@ from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
 
 from vllm.benchmarks.datasets import (
     AIMODataset,
+    ASRDataset,
     BurstGPTDataset,
     ConversationDataset,
     InstructCoderDataset,
@@ -414,6 +415,12 @@ def get_requests(args, tokenizer):
             dataset_cls = AIMODataset
             common_kwargs["dataset_subset"] = None
             common_kwargs["dataset_split"] = "train"
+        elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS:
+            dataset_cls = ASRDataset
+            common_kwargs["dataset_subset"] = args.hf_subset
+            common_kwargs["dataset_split"] = args.hf_split
+            sample_kwargs["asr_min_audio_len_sec"] = args.asr_min_audio_len_sec
+            sample_kwargs["asr_max_audio_len_sec"] = args.asr_max_audio_len_sec
     elif args.dataset_name == "prefix_repetition":
         dataset_cls = PrefixRepetitionRandomDataset
         sample_kwargs["prefix_len"] = args.prefix_repetition_prefix_len
@@ -557,6 +564,7 @@ def validate_args(args):
         elif args.dataset_path in (
             InstructCoderDataset.SUPPORTED_DATASET_PATHS
             | AIMODataset.SUPPORTED_DATASET_PATHS
+            | ASRDataset.SUPPORTED_DATASET_PATHS
         ):
             assert args.backend == "vllm", (
                 f"{args.dataset_path} needs to use vllm as the backend."
@@ -841,6 +849,20 @@ def add_cli_args(parser: argparse.ArgumentParser):
     add_random_dataset_base_args(parser)
     add_random_multimodal_dataset_args(parser)
 
+    # ASR dataset
+    parser.add_argument(
+        "--asr-min-audio-len-sec",
+        type=float,
+        default=0.0,
+        help="Minimum audio duration in seconds for ASR dataset filtering.",
+    )
+    parser.add_argument(
+        "--asr-max-audio-len-sec",
+        type=float,
+        default=float("inf"),
+        help="Maximum audio duration in seconds for ASR dataset filtering.",
+    )
+
     parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user