[Benchmark] Add --async-engine option to benchmark_throughput.py (#7964)
This commit is contained in:
@@ -6,13 +6,16 @@ import time
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import uvloop
|
||||
from tqdm import tqdm
|
||||
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||
PreTrainedTokenizerBase)
|
||||
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
||||
from vllm.entrypoints.openai.api_server import (
|
||||
build_async_engine_client_from_engine_args)
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
|
||||
|
||||
|
||||
def sample_requests(
|
||||
@@ -135,6 +138,93 @@ def run_vllm(
|
||||
return end - start
|
||||
|
||||
|
||||
async def run_vllm_async(
|
||||
requests: List[Tuple[str, int, int]],
|
||||
model: str,
|
||||
tokenizer: str,
|
||||
quantization: Optional[str],
|
||||
tensor_parallel_size: int,
|
||||
seed: int,
|
||||
n: int,
|
||||
use_beam_search: bool,
|
||||
trust_remote_code: bool,
|
||||
dtype: str,
|
||||
max_model_len: Optional[int],
|
||||
enforce_eager: bool,
|
||||
kv_cache_dtype: str,
|
||||
quantization_param_path: Optional[str],
|
||||
device: str,
|
||||
enable_prefix_caching: bool,
|
||||
enable_chunked_prefill: bool,
|
||||
max_num_batched_tokens: int,
|
||||
distributed_executor_backend: Optional[str],
|
||||
gpu_memory_utilization: float = 0.9,
|
||||
num_scheduler_steps: int = 1,
|
||||
use_v2_block_manager: bool = False,
|
||||
download_dir: Optional[str] = None,
|
||||
load_format: str = EngineArgs.load_format,
|
||||
disable_async_output_proc: bool = False,
|
||||
disable_frontend_multiprocessing: bool = False,
|
||||
) -> float:
|
||||
from vllm import SamplingParams
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
quantization=quantization,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
seed=seed,
|
||||
trust_remote_code=trust_remote_code,
|
||||
dtype=dtype,
|
||||
max_model_len=max_model_len,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
enforce_eager=enforce_eager,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
quantization_param_path=quantization_param_path,
|
||||
device=device,
|
||||
enable_prefix_caching=enable_prefix_caching,
|
||||
download_dir=download_dir,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
load_format=load_format,
|
||||
num_scheduler_steps=num_scheduler_steps,
|
||||
use_v2_block_manager=use_v2_block_manager,
|
||||
disable_async_output_proc=disable_async_output_proc,
|
||||
worker_use_ray=False,
|
||||
engine_use_ray=False,
|
||||
disable_log_requests=True,
|
||||
)
|
||||
|
||||
async with build_async_engine_client_from_engine_args(
|
||||
engine_args, disable_frontend_multiprocessing) as llm:
|
||||
|
||||
# Add the requests to the engine.
|
||||
prompts: List[str] = []
|
||||
sampling_params: List[SamplingParams] = []
|
||||
for prompt, _, output_len in requests:
|
||||
prompts.append(prompt)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=0.0 if use_beam_search else 1.0,
|
||||
top_p=1.0,
|
||||
use_beam_search=use_beam_search,
|
||||
ignore_eos=True,
|
||||
max_tokens=output_len,
|
||||
))
|
||||
|
||||
generators = []
|
||||
start = time.perf_counter()
|
||||
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
|
||||
generator = llm.generate(prompt, sp, request_id=f"test{i}")
|
||||
generators.append(generator)
|
||||
all_gens = merge_async_iterators(*generators)
|
||||
async for i, res in all_gens:
|
||||
pass
|
||||
end = time.perf_counter()
|
||||
return end - start
|
||||
|
||||
|
||||
def run_hf(
|
||||
requests: List[Tuple[str, int, int]],
|
||||
model: str,
|
||||
@@ -230,7 +320,7 @@ def main(args: argparse.Namespace):
|
||||
args.output_len)
|
||||
|
||||
if args.backend == "vllm":
|
||||
elapsed_time = run_vllm(
|
||||
run_args = [
|
||||
requests, args.model, args.tokenizer, args.quantization,
|
||||
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
|
||||
args.trust_remote_code, args.dtype, args.max_model_len,
|
||||
@@ -240,7 +330,14 @@ def main(args: argparse.Namespace):
|
||||
args.max_num_batched_tokens, args.distributed_executor_backend,
|
||||
args.gpu_memory_utilization, args.num_scheduler_steps,
|
||||
args.use_v2_block_manager, args.download_dir, args.load_format,
|
||||
args.disable_async_output_proc)
|
||||
args.disable_async_output_proc
|
||||
]
|
||||
|
||||
if args.async_engine:
|
||||
run_args.append(args.disable_frontend_multiprocessing)
|
||||
elapsed_time = uvloop.run(run_vllm_async(*run_args))
|
||||
else:
|
||||
elapsed_time = run_vllm(*run_args)
|
||||
elif args.backend == "hf":
|
||||
assert args.tensor_parallel_size == 1
|
||||
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
||||
@@ -426,6 +523,14 @@ if __name__ == "__main__":
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Disable async output processor for vLLM backend.")
|
||||
parser.add_argument("--async-engine",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Use vLLM async engine rather than LLM class.")
|
||||
parser.add_argument("--disable-frontend-multiprocessing",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Disable decoupled async engine frontend.")
|
||||
args = parser.parse_args()
|
||||
if args.tokenizer is None:
|
||||
args.tokenizer = args.model
|
||||
|
||||
Reference in New Issue
Block a user