[Feature] Add visionarena offline support for benchmark_throughput (#14654)
Signed-off-by: Jennifer Zhao <7443418+JenZhao@users.noreply.github.com> Signed-off-by: Jennifer Zhao <ai.jenniferzhao@gmail.com> Co-authored-by: Jennifer Zhao <7443418+JenZhao@users.noreply.github.com> Co-authored-by: Jennifer Zhao <JenZhao@users.noreply.github.com> Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
This commit is contained in:
@@ -11,8 +11,9 @@ from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import uvloop
|
||||
from benchmark_dataset import (BurstGPTDataset, RandomDataset, SampleRequest,
|
||||
ShareGPTDataset, SonnetDataset)
|
||||
from benchmark_dataset import (BurstGPTDataset, HuggingFaceDataset,
|
||||
RandomDataset, SampleRequest, ShareGPTDataset,
|
||||
SonnetDataset, VisionArenaDataset)
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
from tqdm import tqdm
|
||||
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||
@@ -23,6 +24,7 @@ from vllm.entrypoints.openai.api_server import (
|
||||
build_async_engine_client_from_engine_args)
|
||||
from vllm.inputs import TextPrompt, TokensPrompt
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
|
||||
|
||||
@@ -32,7 +34,7 @@ def run_vllm(
|
||||
n: int,
|
||||
engine_args: EngineArgs,
|
||||
disable_detokenize: bool = False,
|
||||
) -> float:
|
||||
) -> tuple[float, Optional[list[RequestOutput]]]:
|
||||
from vllm import LLM, SamplingParams
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
assert all(
|
||||
@@ -66,12 +68,13 @@ def run_vllm(
|
||||
|
||||
use_beam_search = False
|
||||
|
||||
outputs = None
|
||||
if not use_beam_search:
|
||||
start = time.perf_counter()
|
||||
llm.generate(prompts,
|
||||
sampling_params,
|
||||
lora_request=lora_requests,
|
||||
use_tqdm=True)
|
||||
outputs = llm.generate(prompts,
|
||||
sampling_params,
|
||||
lora_request=lora_requests,
|
||||
use_tqdm=True)
|
||||
end = time.perf_counter()
|
||||
else:
|
||||
assert lora_requests is None, "BeamSearch API does not support LoRA"
|
||||
@@ -89,7 +92,46 @@ def run_vllm(
|
||||
ignore_eos=True,
|
||||
))
|
||||
end = time.perf_counter()
|
||||
return end - start
|
||||
return end - start, outputs
|
||||
|
||||
|
||||
def run_vllm_chat(
|
||||
requests: list[SampleRequest],
|
||||
n: int,
|
||||
engine_args: EngineArgs,
|
||||
disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]:
|
||||
"""
|
||||
Run vLLM chat benchmark. This function is recommended ONLY for benchmarking
|
||||
multimodal models as it properly handles multimodal inputs and chat
|
||||
formatting. For non-multimodal models, use run_vllm() instead.
|
||||
"""
|
||||
from vllm import LLM, SamplingParams
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
|
||||
assert all(
|
||||
llm.llm_engine.model_config.max_model_len >= (
|
||||
request.prompt_len + request.expected_output_len)
|
||||
for request in requests), (
|
||||
"Please ensure that max_model_len is greater than the sum of "
|
||||
"prompt_len and expected_output_len for all requests.")
|
||||
|
||||
prompts = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
for request in requests:
|
||||
prompts.append(request.prompt)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
detokenize=not disable_detokenize,
|
||||
))
|
||||
start = time.perf_counter()
|
||||
outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
|
||||
end = time.perf_counter()
|
||||
return end - start, outputs
|
||||
|
||||
|
||||
async def run_vllm_async(
|
||||
@@ -264,6 +306,8 @@ def get_requests(args, tokenizer):
|
||||
dataset_cls = RandomDataset
|
||||
elif args.dataset_name == "sharegpt":
|
||||
dataset_cls = ShareGPTDataset
|
||||
if args.backend == "vllm-chat":
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
elif args.dataset_name == "sonnet":
|
||||
assert tokenizer.chat_template or tokenizer.default_chat_template, (
|
||||
"Tokenizer/model must have chat template for sonnet dataset.")
|
||||
@@ -272,6 +316,19 @@ def get_requests(args, tokenizer):
|
||||
sample_kwargs["return_prompt_formatted"] = True
|
||||
elif args.dataset_name == "burstgpt":
|
||||
dataset_cls = BurstGPTDataset
|
||||
elif args.dataset_name == "hf":
|
||||
if args.backend != "vllm-chat":
|
||||
raise ValueError(
|
||||
"hf datasets only are supported by vllm-chat backend")
|
||||
# Choose between VisionArenaDataset and HuggingFaceDataset based on
|
||||
# provided parameters.
|
||||
dataset_cls = (VisionArenaDataset if args.dataset_path
|
||||
== VisionArenaDataset.VISION_ARENA_DATASET_PATH
|
||||
and args.hf_subset is None else HuggingFaceDataset)
|
||||
common_kwargs['dataset_subset'] = args.hf_subset
|
||||
common_kwargs['dataset_split'] = args.hf_split
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
|
||||
# Remove None values
|
||||
@@ -290,6 +347,7 @@ def main(args: argparse.Namespace):
|
||||
requests = get_requests(args, tokenizer)
|
||||
is_multi_modal = any(request.multi_modal_data is not None
|
||||
for request in requests)
|
||||
request_outputs: Optional[list[RequestOutput]] = None
|
||||
if args.backend == "vllm":
|
||||
if args.async_engine:
|
||||
elapsed_time = uvloop.run(
|
||||
@@ -301,9 +359,9 @@ def main(args: argparse.Namespace):
|
||||
args.disable_detokenize,
|
||||
))
|
||||
else:
|
||||
elapsed_time = run_vllm(requests, args.n,
|
||||
EngineArgs.from_cli_args(args),
|
||||
args.disable_detokenize)
|
||||
elapsed_time, request_outputs = run_vllm(
|
||||
requests, args.n, EngineArgs.from_cli_args(args),
|
||||
args.disable_detokenize)
|
||||
elif args.backend == "hf":
|
||||
assert args.tensor_parallel_size == 1
|
||||
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
||||
@@ -312,20 +370,45 @@ def main(args: argparse.Namespace):
|
||||
elif args.backend == "mii":
|
||||
elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
|
||||
args.output_len)
|
||||
elif args.backend == "vllm-chat":
|
||||
elapsed_time, request_outputs = run_vllm_chat(
|
||||
requests, args.n, EngineArgs.from_cli_args(args),
|
||||
args.disable_detokenize)
|
||||
else:
|
||||
raise ValueError(f"Unknown backend: {args.backend}")
|
||||
total_num_tokens = sum(request.prompt_len + request.expected_output_len
|
||||
for request in requests)
|
||||
total_output_tokens = sum(request.expected_output_len
|
||||
for request in requests)
|
||||
if is_multi_modal:
|
||||
print("\033[91mWARNING\033[0m: Multi-modal request detected. The "
|
||||
|
||||
if request_outputs:
|
||||
# Note: with the vllm and vllm-chat backends,
|
||||
# we have request_outputs, which we use to count tokens.
|
||||
total_prompt_tokens = 0
|
||||
total_output_tokens = 0
|
||||
for ro in request_outputs:
|
||||
if not isinstance(ro, RequestOutput):
|
||||
continue
|
||||
total_prompt_tokens += len(
|
||||
ro.prompt_token_ids) if ro.prompt_token_ids else 0
|
||||
total_output_tokens += sum(
|
||||
len(o.token_ids) for o in ro.outputs if o)
|
||||
total_num_tokens = total_prompt_tokens + total_output_tokens
|
||||
else:
|
||||
total_num_tokens = sum(r.prompt_len + r.expected_output_len
|
||||
for r in requests)
|
||||
total_output_tokens = sum(r.expected_output_len for r in requests)
|
||||
total_prompt_tokens = total_num_tokens - total_output_tokens
|
||||
|
||||
if is_multi_modal and args.backend != "vllm-chat":
|
||||
print("\033[91mWARNING\033[0m: Multi-modal request with "
|
||||
f"{args.backend} backend detected. The "
|
||||
"following metrics are not accurate because image tokens are not"
|
||||
" counted. See vllm-project/vllm/issues/9778 for details.")
|
||||
# TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
|
||||
# vllm-chat backend counts the image tokens now
|
||||
|
||||
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
||||
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
|
||||
f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
|
||||
print(f"Total num prompt tokens: {total_prompt_tokens}")
|
||||
print(f"Total num output tokens: {total_output_tokens}")
|
||||
|
||||
# Output JSON results if specified
|
||||
if args.output_json:
|
||||
@@ -341,17 +424,100 @@ def main(args: argparse.Namespace):
|
||||
save_to_pytorch_benchmark_format(args, results)
|
||||
|
||||
|
||||
def validate_args(args):
|
||||
"""
|
||||
Validate command-line arguments.
|
||||
"""
|
||||
|
||||
# === Deprecation and Defaulting ===
|
||||
if args.dataset is not None:
|
||||
warnings.warn(
|
||||
"The '--dataset' argument will be deprecated in the next release. "
|
||||
"Please use '--dataset-name' and '--dataset-path' instead.",
|
||||
stacklevel=2)
|
||||
args.dataset_path = args.dataset
|
||||
|
||||
if not getattr(args, "tokenizer", None):
|
||||
args.tokenizer = args.model
|
||||
|
||||
# === Backend Validation ===
|
||||
valid_backends = {"vllm", "hf", "mii", "vllm-chat"}
|
||||
if args.backend not in valid_backends:
|
||||
raise ValueError(f"Unsupported backend: {args.backend}")
|
||||
|
||||
# === Dataset Configuration ===
|
||||
if not args.dataset and not args.dataset_path:
|
||||
print(
|
||||
"When dataset path is not set, it will default to random dataset")
|
||||
args.dataset_name = 'random'
|
||||
if args.input_len is None:
|
||||
raise ValueError("input_len must be provided for a random dataset")
|
||||
|
||||
# === Dataset Name Specific Checks ===
|
||||
# --hf-subset and --hf-split: only used
|
||||
# when dataset_name is 'hf'
|
||||
if args.dataset_name != "hf" and (
|
||||
getattr(args, "hf_subset", None) is not None
|
||||
or getattr(args, "hf_split", None) is not None):
|
||||
warnings.warn("--hf-subset and --hf-split will be ignored \
|
||||
since --dataset-name is not 'hf'.",
|
||||
stacklevel=2)
|
||||
elif args.dataset_name == "hf" and args.backend != "vllm-chat":
|
||||
raise ValueError(
|
||||
"When --dataset-name is 'hf', backend must be 'vllm-chat'")
|
||||
|
||||
# --random-range-ratio: only used when dataset_name is 'random'
|
||||
if args.dataset_name != 'random' and args.random_range_ratio is not None:
|
||||
warnings.warn("--random-range-ratio will be ignored since \
|
||||
--dataset-name is not 'random'.",
|
||||
stacklevel=2)
|
||||
|
||||
# --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
|
||||
# set.
|
||||
if args.dataset_name not in {"random", "sonnet", None
|
||||
} and args.prefix_len is not None:
|
||||
warnings.warn("--prefix-len will be ignored since --dataset-name\
|
||||
is not 'random', 'sonnet', or not set.",
|
||||
stacklevel=2)
|
||||
|
||||
# === LoRA Settings ===
|
||||
if getattr(args, "enable_lora", False) and args.backend != "vllm":
|
||||
raise ValueError(
|
||||
"LoRA benchmarking is only supported for vLLM backend")
|
||||
if getattr(args, "enable_lora", False) and args.lora_path is None:
|
||||
raise ValueError("LoRA path must be provided when enable_lora is True")
|
||||
|
||||
# === Backend-specific Validations ===
|
||||
if args.backend == "hf" and args.hf_max_batch_size is None:
|
||||
raise ValueError("HF max batch size is required for HF backend")
|
||||
if args.backend != "hf" and args.hf_max_batch_size is not None:
|
||||
raise ValueError("HF max batch size is only for HF backend.")
|
||||
|
||||
if args.backend in {"hf", "mii"} and getattr(args, "quantization",
|
||||
None) is not None:
|
||||
raise ValueError("Quantization is only for vLLM backend.")
|
||||
|
||||
if args.backend == "mii" and args.dtype != "auto":
|
||||
raise ValueError("dtype must be auto for MII backend.")
|
||||
if args.backend == "mii" and args.n != 1:
|
||||
raise ValueError("n must be 1 for MII backend.")
|
||||
if args.backend == "mii" and args.tokenizer != args.model:
|
||||
raise ValueError(
|
||||
"Tokenizer must be the same as the model for MII backend.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
|
||||
parser.add_argument("--backend",
|
||||
type=str,
|
||||
choices=["vllm", "hf", "mii"],
|
||||
choices=["vllm", "hf", "mii", "vllm-chat"],
|
||||
default="vllm")
|
||||
parser.add_argument("--dataset-name",
|
||||
type=str,
|
||||
choices=["sharegpt", "random", "sonnet", "burstgpt"],
|
||||
help="Name of the dataset to benchmark on.",
|
||||
default="sharegpt")
|
||||
parser.add_argument(
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
|
||||
help="Name of the dataset to benchmark on.",
|
||||
default="sharegpt")
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
@@ -419,55 +585,24 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--random-range-ratio",
|
||||
type=float,
|
||||
default=1.0,
|
||||
default=None,
|
||||
help="Range of sampled ratio of input/output length, "
|
||||
"used only for RandomDataSet.",
|
||||
)
|
||||
|
||||
# hf dtaset
|
||||
parser.add_argument("--hf-subset",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Subset of the HF dataset.")
|
||||
parser.add_argument("--hf-split",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Split of the HF dataset.")
|
||||
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
if args.tokenizer is None:
|
||||
args.tokenizer = args.model
|
||||
if args.dataset is not None:
|
||||
warnings.warn(
|
||||
"The '--dataset' argument will be deprecated in the next "
|
||||
"release. Please use '--dataset-name' and "
|
||||
"'--dataset-path' in the future runs.",
|
||||
stacklevel=2)
|
||||
args.dataset_path = args.dataset
|
||||
if args.dataset is None and args.dataset_path is None:
|
||||
# for random dataset, the default sampling setting is in
|
||||
# benchmark_dataset.RandomDataset
|
||||
print("When dataset is not set, it will default to random dataset")
|
||||
else:
|
||||
assert args.input_len is None
|
||||
if args.enable_lora:
|
||||
assert args.lora_path is not None
|
||||
|
||||
if args.backend == "vllm":
|
||||
if args.hf_max_batch_size is not None:
|
||||
raise ValueError("HF max batch size is only for HF backend.")
|
||||
elif args.backend == "hf":
|
||||
if args.hf_max_batch_size is None:
|
||||
raise ValueError("HF max batch size is required for HF backend.")
|
||||
if args.quantization is not None:
|
||||
raise ValueError("Quantization is only for vLLM backend.")
|
||||
if args.enable_lora is not None:
|
||||
raise ValueError("LoRA benchmarking is only supported for vLLM"
|
||||
" backend")
|
||||
elif args.backend == "mii":
|
||||
if args.dtype != "auto":
|
||||
raise ValueError("dtype must be auto for MII backend.")
|
||||
if args.n != 1:
|
||||
raise ValueError("n must be 1 for MII backend.")
|
||||
if args.quantization is not None:
|
||||
raise ValueError("Quantization is only for vLLM backend.")
|
||||
if args.hf_max_batch_size is not None:
|
||||
raise ValueError("HF max batch size is only for HF backend.")
|
||||
if args.tokenizer != args.model:
|
||||
raise ValueError("Tokenizer must be the same as the model for MII "
|
||||
"backend.")
|
||||
if args.enable_lora is not None:
|
||||
raise ValueError("LoRA benchmarking is only supported for vLLM"
|
||||
" backend")
|
||||
validate_args(args)
|
||||
main(args)
|
||||
|
||||
Reference in New Issue
Block a user