Implement prompt logprobs & Batched topk for computing logprobs (#1328)
Co-authored-by: Yunmo Chen <16273544+wanmok@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from typing import Dict, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from vllm.sequence import SequenceGroup, SequenceStatus
|
||||
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SequenceGroup,
|
||||
SequenceStatus)
|
||||
|
||||
|
||||
class CompletionOutput:
|
||||
@@ -23,7 +24,7 @@ class CompletionOutput:
|
||||
text: str,
|
||||
token_ids: List[int],
|
||||
cumulative_logprob: float,
|
||||
logprobs: Optional[List[Dict[int, float]]],
|
||||
logprobs: Optional[SampleLogprobs],
|
||||
finish_reason: Optional[str] = None,
|
||||
) -> None:
|
||||
self.index = index
|
||||
@@ -61,12 +62,14 @@ class RequestOutput:
|
||||
request_id: str,
|
||||
prompt: str,
|
||||
prompt_token_ids: List[int],
|
||||
prompt_logprobs: Optional[PromptLogprobs],
|
||||
outputs: List[CompletionOutput],
|
||||
finished: bool,
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.prompt = prompt
|
||||
self.prompt_token_ids = prompt_token_ids
|
||||
self.prompt_logprobs = prompt_logprobs
|
||||
self.outputs = outputs
|
||||
self.finished = finished
|
||||
|
||||
@@ -91,7 +94,7 @@ class RequestOutput:
|
||||
# NOTE: We need to take care of this case because the sequence
|
||||
# always has the logprobs of the sampled tokens even if the
|
||||
# logprobs are not requested.
|
||||
logprobs = {}
|
||||
logprobs = None
|
||||
finshed_reason = SequenceStatus.get_finished_reason(seq.status)
|
||||
output = CompletionOutput(seqs.index(seq), seq.output_text,
|
||||
seq.get_output_token_ids(),
|
||||
@@ -100,15 +103,17 @@ class RequestOutput:
|
||||
outputs.append(output)
|
||||
|
||||
# Every sequence in the sequence group should have the same prompt.
|
||||
prompt = top_n_seqs[0].prompt
|
||||
prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
|
||||
prompt = seq_group.prompt
|
||||
prompt_token_ids = seq_group.prompt_token_ids
|
||||
prompt_logprobs = seq_group.prompt_logprobs
|
||||
finished = seq_group.is_finished()
|
||||
return cls(seq_group.request_id, prompt, prompt_token_ids, outputs,
|
||||
finished)
|
||||
return cls(seq_group.request_id, prompt, prompt_token_ids,
|
||||
prompt_logprobs, outputs, finished)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"RequestOutput(request_id={self.request_id}, "
|
||||
f"prompt={self.prompt!r}, "
|
||||
f"prompt_token_ids={self.prompt_token_ids}, "
|
||||
f"prompt_logprobs={self.prompt_logprobs}, "
|
||||
f"outputs={self.outputs}, "
|
||||
f"finished={self.finished})")
|
||||
|
||||
Reference in New Issue
Block a user