Implement prompt logprobs & Batched topk for computing logprobs (#1328)

commit 9d9072a069
parent 928de46888
Author: Zhuohan Li
Co-authored-by: Yunmo Chen <16273544+wanmok@users.noreply.github.com>
Committed-by: GitHub
Date: 2023-10-16 10:56:50 -07:00
14 changed files with 369 additions and 130 deletions
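
Before the diff itself: the "Batched topk" in the title refers to taking the top-k log-probabilities for every token position in a batch with a single call, rather than looping per sequence. A minimal illustrative sketch of that idea in PyTorch (shapes and k are invented; this is not the code from the commit):

import torch

# Compute log-probabilities for a whole batch of token positions at once.
logits = torch.randn(4, 32000)              # [num_tokens, vocab_size]
logprobs = torch.log_softmax(logits, dim=-1)

# One batched top-k call replaces a per-row loop; both results are [num_tokens, 5].
topk_logprobs, topk_ids = logprobs.topk(k=5, dim=-1)

# Each row can then be packed into the Dict[int, float] entries used by the
# PromptLogprobs / SampleLogprobs aliases introduced below.
row0 = dict(zip(topk_ids[0].tolist(), topk_logprobs[0].tolist()))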

vllm/sequence.py

@@ -6,6 +6,9 @@ from typing import Dict, List, Optional, Union
 from vllm.block import LogicalTokenBlock
 from vllm.sampling_params import SamplingParams
 
+PromptLogprobs = List[Optional[Dict[int, float]]]
+SampleLogprobs = List[Dict[int, float]]
+
 
 class SequenceStatus(enum.Enum):
     """Status of a sequence."""
@@ -116,7 +119,7 @@ class Sequence:
         self.block_size = block_size
 
         self.data = SequenceData(prompt_token_ids)
-        self.output_logprobs: List[Dict[int, float]] = []
+        self.output_logprobs: SampleLogprobs = []
         self.output_text = ""
 
         self.logical_token_blocks: List[LogicalTokenBlock] = []
@@ -196,7 +199,7 @@ class Sequence:
         """
         if seq_len is None:
             seq_len = self.get_len()
-            # Note: HF implementation does not count the EOS token
+            # NOTE: HF implementation does not count the EOS token
            # towards the length, we align with that here for testing.
             if (eos_token_id is not None
                     and self.get_last_token_id() == eos_token_id):
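
The NOTE above concerns how sequence length is counted when scoring a finished sequence. A simplified standalone sketch of the behavior it describes (the length-penalty form is an assumption for illustration, not quoted from this file):

from typing import List, Optional


def beam_search_score(token_logprobs: List[float],
                      token_ids: List[int],
                      length_penalty: float = 1.0,
                      eos_token_id: Optional[int] = None) -> float:
    # Match the HF convention referenced in the NOTE: a trailing EOS token
    # does not count towards the length used for the penalty.
    seq_len = len(token_ids)
    if (eos_token_id is not None and token_ids
            and token_ids[-1] == eos_token_id):
        seq_len -= 1
    return sum(token_logprobs) / (seq_len**length_penalty)
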
@@ -238,6 +241,19 @@ class SequenceGroup:
         self.seqs_dict = {seq.seq_id: seq for seq in seqs}
         self.sampling_params = sampling_params
         self.arrival_time = arrival_time
+        self.prompt_logprobs: Optional[PromptLogprobs] = None
+
+    @property
+    def prompt(self) -> str:
+        # All sequences in the group should have the same prompt.
+        # We use the prompt of an arbitrary sequence.
+        return next(iter(self.seqs_dict.values())).prompt
+
+    @property
+    def prompt_token_ids(self) -> List[int]:
+        # All sequences in the group should have the same prompt.
+        # We use the prompt of an arbitrary sequence.
+        return next(iter(self.seqs_dict.values())).data.prompt_token_ids
 
     def get_max_num_running_seqs(self) -> int:
         """The maximum number of sequences running in parallel in the remaining
@@ -370,6 +386,22 @@ class SequenceOutputs:
                 and self.logprobs == other.logprobs)
 
 
+class SequenceGroupOutputs:
+    """The model outputs associated with a sequence group."""
+
+    def __init__(
+        self,
+        samples: List[SequenceOutputs],
+        prompt_logprobs: Optional[PromptLogprobs],
+    ) -> None:
+        self.samples = samples
+        self.prompt_logprobs = prompt_logprobs
+
+    def __repr__(self) -> str:
+        return (f"SequenceGroupOutputs(samples={self.samples}, "
+                f"prompt_logprobs={self.prompt_logprobs})")
+
+
 # For each sequence group, we generate a list of SequenceOutputs object,
 # each of which contains one possible candidate for the next token.
-SamplerOutput = List[List[SequenceOutputs]]
+SamplerOutput = List[SequenceGroupOutputs]
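
Putting the pieces together, a self-contained sketch of the new SamplerOutput shape (stand-in definitions mirroring the diff; token ids and values are invented): each sequence group now yields one SequenceGroupOutputs that bundles its sampled candidates with the optional prompt logprobs, instead of a bare inner list of candidates.

from typing import Dict, List, Optional

PromptLogprobs = List[Optional[Dict[int, float]]]


class SequenceOutputs:
    # Stand-in: one candidate next token proposed for a parent sequence.
    def __init__(self, parent_seq_id: int, output_token: int,
                 logprobs: Dict[int, float]) -> None:
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs


class SequenceGroupOutputs:
    def __init__(self, samples: List[SequenceOutputs],
                 prompt_logprobs: Optional[PromptLogprobs]) -> None:
        self.samples = samples
        self.prompt_logprobs = prompt_logprobs


SamplerOutput = List[SequenceGroupOutputs]

# One entry per sequence group: sampled candidates plus, when the request
# asked for them, the logprobs of the prompt tokens themselves.
output: SamplerOutput = [
    SequenceGroupOutputs(
        samples=[SequenceOutputs(0, 995, {995: -0.7})],
        prompt_logprobs=[None, {995: -2.3}],
    ),
]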