diff --git a/examples/offline_inference/context_extension.py b/examples/offline_inference/context_extension.py
index 67d33e188..fae8590f9 100644
--- a/examples/offline_inference/context_extension.py
+++ b/examples/offline_inference/context_extension.py
@@ -9,7 +9,7 @@
 Usage:
     python examples/offline_inference/context_extension.py
 """
-from vllm import LLM, SamplingParams
+from vllm import LLM, RequestOutput, SamplingParams
 
 
 def create_llm():
@@ -45,13 +45,15 @@ def run_llm_chat(llm):
         {"role": "assistant", "content": "Hello! How can I assist you today?"},
     ]
     outputs = llm.chat(conversation, sampling_params, use_tqdm=False)
-    return outputs
+    return outputs, [
+        conversation,
+    ]
 
 
-def print_outputs(outputs):
+def print_outputs(outputs: list[RequestOutput], conversations: list):
     print("\nGenerated Outputs:\n" + "-" * 80)
-    for output in outputs:
-        prompt = output.prompt
+    for i, output in enumerate(outputs):
+        prompt = conversations[i]
         generated_text = output.outputs[0].text
         print(f"Prompt: {prompt!r}\n")
         print(f"Generated text: {generated_text!r}")
@@ -60,8 +62,8 @@ def print_outputs(outputs):
 
 def main():
     llm = create_llm()
-    outputs = run_llm_chat(llm)
-    print_outputs(outputs)
+    outputs, conversations = run_llm_chat(llm)
+    print_outputs(outputs, conversations)
 
 
 if __name__ == "__main__":
diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py
index 29b2e95d2..a84d5b116 100644
--- a/examples/offline_inference/spec_decode.py
+++ b/examples/offline_inference/spec_decode.py
@@ -152,9 +152,12 @@ def main(args):
 
     # print the generated text
     if args.print_output:
-        for output in outputs:
+        for i, output in enumerate(outputs):
             print("-" * 50)
-            print(f"prompt: {output.prompt}")
+            if not args.custom_mm_prompts:
+                print(f"prompt: {prompts[i].prompt}")
+            else:
+                print(f"prompt: {prompts[i]}")
             print(f"generated text: {output.outputs[0].text}")
             print("-" * 50)
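
For reference, a minimal standalone sketch (not part of the patch) of the pattern both hunks adopt: keep the input prompts alongside the returned RequestOutputs and pair them by position, rather than reading the prompt back via output.prompt. The model name and prompts below are illustrative placeholders, not values taken from the diff.

# Sketch: pair prompts with outputs by index instead of via output.prompt.
# Assumptions: vLLM installed; "facebook/opt-125m" stands in for any model.
from vllm import LLM, RequestOutput, SamplingParams


def print_outputs(outputs: list[RequestOutput], prompts: list) -> None:
    # Each RequestOutput lines up positionally with the prompt that produced it.
    for prompt, output in zip(prompts, outputs):
        print("-" * 50)
        print(f"prompt: {prompt!r}")
        print(f"generated text: {output.outputs[0].text!r}")
        print("-" * 50)


def main() -> None:
    llm = LLM(model="facebook/opt-125m")
    prompts = ["Hello, my name is", "The capital of France is"]
    outputs = llm.generate(prompts, SamplingParams(temperature=0.8, max_tokens=32))
    print_outputs(outputs, prompts)


if __name__ == "__main__":
    main()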