[Bugfix] Fix a bug in RequestOutput.finished (#202)
This commit is contained in:
@@ -60,11 +60,13 @@ class RequestOutput:
|
||||
prompt: str,
|
||||
prompt_token_ids: List[int],
|
||||
outputs: List[CompletionOutput],
|
||||
finished: bool,
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.prompt = prompt
|
||||
self.prompt_token_ids = prompt_token_ids
|
||||
self.outputs = outputs
|
||||
self.finished = finished
|
||||
|
||||
@classmethod
|
||||
def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
|
||||
@@ -95,13 +97,13 @@ class RequestOutput:
|
||||
# Every sequence in the sequence group should have the same prompt.
|
||||
prompt = top_n_seqs[0].prompt
|
||||
prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
|
||||
return cls(seq_group.request_id, prompt, prompt_token_ids, outputs)
|
||||
finished = seq_group.is_finished()
|
||||
return cls(seq_group.request_id, prompt, prompt_token_ids, outputs,
|
||||
finished)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"RequestOutput(request_id={self.request_id}, "
|
||||
f"prompt={self.prompt!r}, "
|
||||
f"prompt_token_ids={self.prompt_token_ids}, "
|
||||
f"outputs={self.outputs})")
|
||||
|
||||
def finished(self) -> bool:
|
||||
return all(output.finished() for output in self.outputs)
|
||||
f"outputs={self.outputs}, "
|
||||
f"finished={self.finished})")
|
||||
|
||||
Reference in New Issue
Block a user