diff --git a/examples/offline_inference/run_one_batch.py b/examples/offline_inference/run_one_batch.py
new file mode 100644
index 000000000..d7692c563
--- /dev/null
+++ b/examples/offline_inference/run_one_batch.py
@@ -0,0 +1,112 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from __future__ import annotations
+
+from vllm import LLM, EngineArgs
+from vllm.config import ProfilerConfig
+from vllm.utils.argparse_utils import FlexibleArgumentParser
+
+DEFAULT_MAX_TOKENS = 16
+
+
+def create_parser() -> FlexibleArgumentParser:
+    parser = FlexibleArgumentParser()
+    EngineArgs.add_cli_args(parser)
+    parser.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
+
+    batch_group = parser.add_argument_group("Batch parameters")
+    batch_group.add_argument("--batch-size", type=int, default=1)
+    batch_group.add_argument("--prompt-size", type=int, default=128)
+    batch_group.add_argument("--prompt-prefix", type=str, default="Hello, my name is")
+
+    profile_group = parser.add_argument_group("Profiling parameters")
+    profile_group.add_argument(
+        "--profile",
+        choices=["none", "prefill", "decode", "both"],
+        default="none",
+    )
+    profile_group.add_argument(
+        "--profile-dir",
+        type=str,
+        default="",
+        help="Required when --profile is not 'none'.",
+    )
+
+    return parser
+
+
+def _build_prompt(prefix: str, prompt_size: int) -> str:
+    if prompt_size <= 0:
+        return ""
+    if not prefix:
+        prefix = " "
+    if len(prefix) >= prompt_size:
+        return prefix[:prompt_size]
+    repeat_count = (prompt_size + len(prefix) - 1) // len(prefix)
+    return (prefix * repeat_count)[:prompt_size]
+
+
+def _build_profiler_config(
+    profile: str, profile_dir: str, max_tokens: int
+) -> ProfilerConfig | None:
+    if profile == "none":
+        return None
+    if not profile_dir:
+        raise ValueError("--profile-dir must be set when profiling is enabled.")
+    if profile == "prefill":
+        delay_iterations = 0
+        max_iterations = 1
+    elif profile == "decode":
+        delay_iterations = 1
+        max_iterations = max(1, max_tokens)
+    else:
+        delay_iterations = 0
+        max_iterations = 0
+
+    return ProfilerConfig(
+        profiler="torch",
+        torch_profiler_dir=profile_dir,
+        delay_iterations=delay_iterations,
+        max_iterations=max_iterations,
+    )
+
+
+def main(args: dict) -> None:
+    max_tokens = DEFAULT_MAX_TOKENS
+    batch_size = args.pop("batch_size")
+    prompt_size = args.pop("prompt_size")
+    prompt_prefix = args.pop("prompt_prefix")
+    profile = args.pop("profile")
+    profile_dir = args.pop("profile_dir")
+
+    profiler_config = _build_profiler_config(profile, profile_dir, max_tokens)
+    if profiler_config is not None:
+        args["profiler_config"] = profiler_config
+
+    llm = LLM(**args)
+
+    sampling_params = llm.get_default_sampling_params()
+    sampling_params.max_tokens = max_tokens
+    sampling_params.min_tokens = max_tokens
+    sampling_params.ignore_eos = True
+
+    prompt = _build_prompt(prompt_prefix, prompt_size)
+    prompts = [prompt] * batch_size
+
+    if profile != "none":
+        llm.start_profile()
+    outputs = llm.generate(prompts, sampling_params)
+    if profile != "none":
+        llm.stop_profile()
+
+    print("-" * 50)
+    for output in outputs:
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {output.prompt!r}\nGenerated text: {generated_text!r}")
+        print("-" * 50)
+
+
+if __name__ == "__main__":
+    parser = create_parser()
+    main(vars(parser.parse_args()))
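
For reference, a sketch of how the new script might be invoked. The flag names and the model default come from the parser above; the batch/prompt sizes and the trace directory here are illustrative values, not recommendations:

    python examples/offline_inference/run_one_batch.py \
        --batch-size 8 \
        --prompt-size 256 \
        --profile prefill \
        --profile-dir ./profile_traces

With --profile prefill the script profiles only the first engine iteration (delay_iterations=0, max_iterations=1), while --profile decode skips the prefill step and profiles up to DEFAULT_MAX_TOKENS decode iterations; --profile both leaves the iteration window unbounded.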