Enable --profile in 'vllm bench throughput' (#24575)

Signed-off-by: Tomas Ruiz <tomas.ruiz.te@gmail.com>
Author: Tomas Ruiz
Date: 2025-09-11 08:06:19 +02:00
Committed by: GitHub
Parent: 3d1393f6fc
Commit: ee0bc5e1b4

@@ -37,6 +37,7 @@ def run_vllm(
     requests: list[SampleRequest],
     n: int,
     engine_args: EngineArgs,
+    do_profile: bool,
     disable_detokenize: bool = False,
 ) -> tuple[float, Optional[list[RequestOutput]]]:
     from vllm import LLM, SamplingParams
@@ -75,10 +76,14 @@ def run_vllm(
     outputs = None
     if not use_beam_search:
         start = time.perf_counter()
+        if do_profile:
+            llm.start_profile()
         outputs = llm.generate(prompts,
                                sampling_params,
                                lora_request=lora_requests,
                                use_tqdm=True)
+        if do_profile:
+            llm.stop_profile()
         end = time.perf_counter()
     else:
         assert lora_requests is None, "BeamSearch API does not support LoRA"
@@ -88,6 +93,8 @@ def run_vllm(
         for request in requests:
             assert request.expected_output_len == output_len
         start = time.perf_counter()
+        if do_profile:
+            llm.start_profile()
         llm.beam_search(
             prompts,
             BeamSearchParams(
@@ -95,6 +102,8 @@ def run_vllm(
                 max_tokens=output_len,
                 ignore_eos=True,
             ))
+        if do_profile:
+            llm.stop_profile()
         end = time.perf_counter()
     return end - start, outputs
 
@@ -103,6 +112,7 @@ def run_vllm_chat(
         requests: list[SampleRequest],
         n: int,
         engine_args: EngineArgs,
+        do_profile: bool,
         disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]:
     """
     Run vLLM chat benchmark. This function is recommended ONLY for benchmarking
@@ -133,7 +143,11 @@ def run_vllm_chat(
                 detokenize=not disable_detokenize,
             ))
     start = time.perf_counter()
+    if do_profile:
+        llm.start_profile()
     outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
+    if do_profile:
+        llm.stop_profile()
     end = time.perf_counter()
     return end - start, outputs
 
@@ -142,6 +156,7 @@ async def run_vllm_async(
     requests: list[SampleRequest],
     n: int,
     engine_args: AsyncEngineArgs,
+    do_profile: bool,
     disable_frontend_multiprocessing: bool = False,
     disable_detokenize: bool = False,
 ) -> float:
@@ -185,6 +200,8 @@ async def run_vllm_async(
 
         generators = []
         start = time.perf_counter()
+        if do_profile:
+            await llm.start_profile()
         for i, (prompt, sp,
                 lr) in enumerate(zip(prompts, sampling_params, lora_requests)):
             generator = llm.generate(prompt,
@@ -195,6 +212,8 @@ async def run_vllm_async(
         all_gens = merge_async_iterators(*generators)
         async for i, res in all_gens:
             pass
+        if do_profile:
+            await llm.stop_profile()
         end = time.perf_counter()
     return end - start
 
@@ -543,6 +562,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
                         type=str,
                         default=None,
                         help="Split of the HF dataset.")
+    parser.add_argument(
+        "--profile",
+        action="store_true",
+        default=False,
+        help="Use Torch Profiler. The env variable "
+        "VLLM_TORCH_PROFILER_DIR must be set to enable profiler.")
 
     # prefix repetition dataset
     prefix_repetition_group = parser.add_argument_group(
@@ -600,22 +625,27 @@ def main(args: argparse.Namespace):
                     requests,
                     args.n,
                     AsyncEngineArgs.from_cli_args(args),
-                    args.disable_frontend_multiprocessing,
-                    args.disable_detokenize,
+                    disable_frontend_multiprocessing=args.disable_frontend_multiprocessing,
+                    disable_detokenize=args.disable_detokenize,
+                    do_profile=args.profile,
                 ))
         else:
             elapsed_time, request_outputs = run_vllm(
                 requests, args.n, EngineArgs.from_cli_args(args),
-                args.disable_detokenize)
+                disable_detokenize=args.disable_detokenize,
+                do_profile=args.profile)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
+        if args.profile:
+            raise NotImplementedError(
+                "Profiling not implemented yet for backend='hf'.")
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                               args.hf_max_batch_size, args.trust_remote_code,
                               args.disable_detokenize)
     elif args.backend == "vllm-chat":
         elapsed_time, request_outputs = run_vllm_chat(
             requests, args.n, EngineArgs.from_cli_args(args),
-            args.disable_detokenize)
+            disable_detokenize=args.disable_detokenize, do_profile=args.profile)
     else:
         raise ValueError(f"Unknown backend: {args.backend}")
 
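
For context, the bracketing that this commit adds around the benchmark's timed sections can also be reproduced directly against the offline LLM API. The sketch below is illustrative only: the model name, prompt, and profiler directory are placeholders, and it assumes VLLM_TORCH_PROFILER_DIR points to a writable directory, which the new --profile help text states is required for the profiler to be enabled.

import os

# Assumption: the profiler directory must be set before the engine is built,
# per the --profile help text added by this commit.
os.environ.setdefault("VLLM_TORCH_PROFILER_DIR", "/tmp/vllm_profile")

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model, for illustration
sampling_params = SamplingParams(max_tokens=32)

# The same start/stop bracketing run_vllm() applies around its timed generate() call.
llm.start_profile()
outputs = llm.generate(["Hello, my name is"], sampling_params, use_tqdm=True)
llm.stop_profile()

From the CLI, the equivalent is to export VLLM_TORCH_PROFILER_DIR and pass --profile to vllm bench throughput; the hf backend rejects the flag with a NotImplementedError, as the last hunk shows.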