fix offline inference chat response prompt (#32088)
Signed-off-by: Andy Xie <andy.xning@gmail.com>
This commit is contained in:
@@ -9,7 +9,7 @@ Usage:
     python examples/offline_inference/context_extension.py
 """

-from vllm import LLM, SamplingParams
+from vllm import LLM, RequestOutput, SamplingParams


 def create_llm():
@@ -45,13 +45,15 @@ def run_llm_chat(llm):
         {"role": "assistant", "content": "Hello! How can I assist you today?"},
     ]
     outputs = llm.chat(conversation, sampling_params, use_tqdm=False)
-    return outputs
+    return outputs, [
+        conversation,
+    ]


-def print_outputs(outputs):
+def print_outputs(outputs: list[RequestOutput], conversations: list):
     print("\nGenerated Outputs:\n" + "-" * 80)
-    for output in outputs:
-        prompt = output.prompt
+    for i, output in enumerate(outputs):
+        prompt = conversations[i]
         generated_text = output.outputs[0].text
         print(f"Prompt: {prompt!r}\n")
         print(f"Generated text: {generated_text!r}")
@@ -60,8 +62,8 @@ def print_outputs(outputs):

 def main():
     llm = create_llm()
-    outputs = run_llm_chat(llm)
-    print_outputs(outputs)
+    outputs, conversations = run_llm_chat(llm)
+    print_outputs(outputs, conversations)


 if __name__ == "__main__":
Reference in New Issue
Block a user