OpenAI Compatible Frontend (#116)
@@ -1,6 +1,6 @@
-from typing import Dict, List
+from typing import Dict, List, Optional

-from cacheflow.sequence import SequenceGroup
+from cacheflow.sequence import SequenceGroup, SequenceStatus


 class CompletionOutput:
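The new SequenceStatus import supplies the status-to-finish_reason mapping used later in this diff. The helper itself lives in cacheflow.sequence and is outside this commit; below is a minimal sketch of what it plausibly looks like, assuming an enum of sequence states (the member names here are assumptions, not taken from this commit):

import enum
from typing import Optional

class SequenceStatus(enum.Enum):
    # Assumed states; the real enum is defined in cacheflow.sequence.
    RUNNING = enum.auto()
    FINISHED_STOPPED = enum.auto()        # generation hit a stop condition
    FINISHED_LENGTH_CAPPED = enum.auto()  # generation hit the length limit

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Optional[str]:
        # OpenAI-style finish_reason strings; None while still running.
        if status == SequenceStatus.FINISHED_STOPPED:
            return "stop"
        if status == SequenceStatus.FINISHED_LENGTH_CAPPED:
            return "length"
        return None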
@@ -12,19 +12,25 @@ class CompletionOutput:
         token_ids: List[int],
         cumulative_logprob: float,
         logprobs: List[Dict[int, float]],
+        finish_reason: Optional[str] = None,
     ) -> None:
         self.index = index
         self.text = text
         self.token_ids = token_ids
         self.cumulative_logprob = cumulative_logprob
         self.logprobs = logprobs
+        self.finish_reason = finish_reason
+
+    def finished(self) -> bool:
+        return self.finish_reason is not None

     def __repr__(self) -> str:
         return (f"CompletionOutput(index={self.index}, "
                 f"text={self.text!r}, "
                 f"token_ids={self.token_ids}, "
                 f"cumulative_logprob={self.cumulative_logprob}, "
-                f"logprobs={self.logprobs})")
+                f"logprobs={self.logprobs}, "
+                f"finish_reason={self.finish_reason})")


 class RequestOutput:
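With finish_reason defaulting to None, finished() doubles as a cheap completion check. A quick illustrative example (the constructed values are made up, not from this commit):

output = CompletionOutput(index=0, text="Hello", token_ids=[15496],
                          cumulative_logprob=-0.7, logprobs=[])
assert not output.finished()      # finish_reason is still None

output.finish_reason = "length"   # e.g. the max token budget was hit
assert output.finished()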
@@ -35,13 +41,11 @@ class RequestOutput:
         prompt: str,
         prompt_token_ids: List[int],
         outputs: List[CompletionOutput],
-        done: bool,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
         self.prompt_token_ids = prompt_token_ids
         self.outputs = outputs
-        self.done = done

     @classmethod
     def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
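Dropping done from the constructor means RequestOutput no longer carries a separately maintained completion flag; as the final hunk shows, the same information is now derived on demand from the per-sequence finish reasons.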
@@ -57,25 +61,28 @@ class RequestOutput:
         outputs: List[CompletionOutput] = []
         for seq in top_n_seqs:
             logprobs = seq.output_logprobs
-            if seq_group.sampling_params.logprobs == 0:
+            if seq_group.sampling_params.logprobs is None:
                 # NOTE: We need to take care of this case because the sequence
                 # always has the logprobs of the sampled tokens even if the
                 # logprobs are not requested.
                 logprobs = {}
+            finished_reason = SequenceStatus.get_finished_reason(seq.status)
             output = CompletionOutput(seqs.index(seq), seq.output_text,
                                       seq.get_output_token_ids(),
-                                      seq.get_cumulative_logprob(), logprobs)
+                                      seq.get_cumulative_logprob(), logprobs,
+                                      finished_reason)
             outputs.append(output)

         # Every sequence in the sequence group should have the same prompt.
         prompt = top_n_seqs[0].prompt
         prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
-        return cls(seq_group.request_id, prompt, prompt_token_ids, outputs,
-                   seq_group.is_finished())
+        return cls(seq_group.request_id, prompt, prompt_token_ids, outputs)

     def __repr__(self) -> str:
         return (f"RequestOutput(request_id={self.request_id}, "
                 f"prompt={self.prompt!r}, "
                 f"prompt_token_ids={self.prompt_token_ids}, "
-                f"outputs={self.outputs}, "
-                f"done={self.done})")
+                f"outputs={self.outputs})")
+
+    def finished(self) -> bool:
+        return all(output.finished() for output in self.outputs)
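Taken together, callers can poll completion without the removed done flag. A hedged sketch of a consuming loop; the server object and its step() and has_unfinished_requests() methods are assumptions standing in for whatever engine produces RequestOutput objects, not part of this diff:

def print_finished(server) -> None:
    # Assumed interface: step() advances the engine one iteration and
    # returns the RequestOutput objects updated in that iteration.
    while server.has_unfinished_requests():
        for request_output in server.step():
            if request_output.finished():
                for completion in request_output.outputs:
                    print(completion.finish_reason, completion.text)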