Enable --profile in 'vllm bench throughput' (#24575)
Signed-off-by: Tomas Ruiz <tomas.ruiz.te@gmail.com>
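
Passing --profile now wraps the benchmarked generate/chat/beam-search calls in llm.start_profile() / llm.stop_profile(). Per the new flag's help text, the environment variable VLLM_TORCH_PROFILER_DIR must be set for the profiler to be enabled, e.g. (output directory and remaining arguments are illustrative): VLLM_TORCH_PROFILER_DIR=./vllm_profile vllm bench throughput --profile ... Profiling is not implemented for the 'hf' backend.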
@@ -37,6 +37,7 @@ def run_vllm(
     requests: list[SampleRequest],
     n: int,
     engine_args: EngineArgs,
+    do_profile: bool,
     disable_detokenize: bool = False,
 ) -> tuple[float, Optional[list[RequestOutput]]]:
     from vllm import LLM, SamplingParams
@@ -75,10 +76,14 @@ def run_vllm(
     outputs = None
     if not use_beam_search:
         start = time.perf_counter()
+        if do_profile:
+            llm.start_profile()
         outputs = llm.generate(prompts,
                                sampling_params,
                                lora_request=lora_requests,
                                use_tqdm=True)
+        if do_profile:
+            llm.stop_profile()
         end = time.perf_counter()
     else:
         assert lora_requests is None, "BeamSearch API does not support LoRA"
@@ -88,6 +93,8 @@ def run_vllm(
         for request in requests:
             assert request.expected_output_len == output_len
         start = time.perf_counter()
+        if do_profile:
+            llm.start_profile()
         llm.beam_search(
             prompts,
             BeamSearchParams(
@@ -95,6 +102,8 @@ def run_vllm(
                 max_tokens=output_len,
                 ignore_eos=True,
             ))
+        if do_profile:
+            llm.stop_profile()
         end = time.perf_counter()
     return end - start, outputs
 
@@ -103,6 +112,7 @@ def run_vllm_chat(
     requests: list[SampleRequest],
     n: int,
     engine_args: EngineArgs,
+    do_profile: bool,
     disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]:
     """
     Run vLLM chat benchmark. This function is recommended ONLY for benchmarking
@@ -133,7 +143,11 @@ def run_vllm_chat(
             detokenize=not disable_detokenize,
         ))
     start = time.perf_counter()
+    if do_profile:
+        llm.start_profile()
     outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
+    if do_profile:
+        llm.stop_profile()
     end = time.perf_counter()
     return end - start, outputs
 
@@ -142,6 +156,7 @@ async def run_vllm_async(
     requests: list[SampleRequest],
     n: int,
     engine_args: AsyncEngineArgs,
+    do_profile: bool,
     disable_frontend_multiprocessing: bool = False,
     disable_detokenize: bool = False,
 ) -> float:
@@ -185,6 +200,8 @@ async def run_vllm_async(
 
         generators = []
         start = time.perf_counter()
+        if do_profile:
+            await llm.start_profile()
         for i, (prompt, sp,
                 lr) in enumerate(zip(prompts, sampling_params, lora_requests)):
             generator = llm.generate(prompt,
@@ -195,6 +212,8 @@ async def run_vllm_async(
         all_gens = merge_async_iterators(*generators)
         async for i, res in all_gens:
             pass
+        if do_profile:
+            await llm.stop_profile()
         end = time.perf_counter()
         return end - start
 
@@ -543,6 +562,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
         type=str,
         default=None,
         help="Split of the HF dataset.")
+    parser.add_argument(
+        "--profile",
+        action="store_true",
+        default=False,
+        help="Use Torch Profiler. The env variable "
+        "VLLM_TORCH_PROFILER_DIR must be set to enable profiler.")
 
     # prefix repetition dataset
     prefix_repetition_group = parser.add_argument_group(
@@ -600,22 +625,27 @@ def main(args: argparse.Namespace):
                     requests,
                     args.n,
                     AsyncEngineArgs.from_cli_args(args),
-                    args.disable_frontend_multiprocessing,
-                    args.disable_detokenize,
+                    disable_frontend_multiprocessing=args.disable_frontend_multiprocessing,
+                    disable_detokenize=args.disable_detokenize,
+                    do_profile=args.profile,
                 ))
         else:
             elapsed_time, request_outputs = run_vllm(
                 requests, args.n, EngineArgs.from_cli_args(args),
-                args.disable_detokenize)
+                disable_detokenize=args.disable_detokenize,
+                do_profile=args.profile)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
+        if args.profile:
+            raise NotImplementedError(
+                "Profiling not implemented yet for backend='hf'.")
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                               args.hf_max_batch_size, args.trust_remote_code,
                               args.disable_detokenize)
     elif args.backend == "vllm-chat":
         elapsed_time, request_outputs = run_vllm_chat(
             requests, args.n, EngineArgs.from_cli_args(args),
-            args.disable_detokenize)
+            disable_detokenize=args.disable_detokenize, do_profile=args.profile)
     else:
         raise ValueError(f"Unknown backend: {args.backend}")
 
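For reference, each runner applies the same conditional-profiling pattern; the sketch below is a minimal distillation of it, not part of the diff (the llm object, prompts, and sampling_params are assumed to already exist). Note that start_profile()/stop_profile() sit inside the perf_counter window, so any profiler overhead is included in the reported elapsed time.

    import time

    def timed_generate(llm, prompts, sampling_params, do_profile: bool):
        # The profiler brackets the generate call but stays inside the timed
        # window, so the measured duration includes profiling overhead.
        start = time.perf_counter()
        if do_profile:
            llm.start_profile()  # per the --profile help text, VLLM_TORCH_PROFILER_DIR must be set
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        if do_profile:
            llm.stop_profile()
        end = time.perf_counter()
        return end - start, outputs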