[Misc] refactor argument parsing in examples (#16635)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
This commit is contained in:
Reid
2025-04-15 16:05:30 +08:00
committed by GitHub
parent b590adfdc1
commit 6ae996a873
25 changed files with 595 additions and 411 deletions

View File

@@ -4,6 +4,22 @@ from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser
def create_parser():
parser = FlexibleArgumentParser()
# Add engine args
engine_group = parser.add_argument_group("Engine arguments")
EngineArgs.add_cli_args(engine_group)
engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
# Add sampling params
sampling_group = parser.add_argument_group("Sampling parameters")
sampling_group.add_argument("--max-tokens", type=int)
sampling_group.add_argument("--temperature", type=float)
sampling_group.add_argument("--top-p", type=float)
sampling_group.add_argument("--top-k", type=int)
return parser
def main(args: dict):
# Pop arguments not used by LLM
max_tokens = args.pop("max_tokens")
@@ -35,23 +51,15 @@ def main(args: dict):
]
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
if __name__ == "__main__":
parser = FlexibleArgumentParser()
# Add engine args
engine_group = parser.add_argument_group("Engine arguments")
EngineArgs.add_cli_args(engine_group)
engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
# Add sampling params
sampling_group = parser.add_argument_group("Sampling parameters")
sampling_group.add_argument("--max-tokens", type=int)
sampling_group.add_argument("--temperature", type=float)
sampling_group.add_argument("--top-p", type=float)
sampling_group.add_argument("--top-k", type=int)
parser = create_parser()
args: dict = vars(parser.parse_args())
main(args)