Implement prompt logprobs & Batched topk for computing logprobs (#1328)
Co-authored-by: Yunmo Chen <16273544+wanmok@users.noreply.github.com>
@@ -6,6 +6,9 @@ from typing import Dict, List, Optional, Union
 from vllm.block import LogicalTokenBlock
 from vllm.sampling_params import SamplingParams
 
+PromptLogprobs = List[Optional[Dict[int, float]]]
+SampleLogprobs = List[Dict[int, float]]
+
 
 class SequenceStatus(enum.Enum):
     """Status of a sequence."""
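A minimal sketch of the shapes these two aliases describe. The token ids and logprob values below are invented, and the None first entry reflects an assumption about how callers fill the list (the first prompt token has no preceding context to score):

    from vllm.sequence import PromptLogprobs, SampleLogprobs

    # Invented values, for illustration only.
    prompt_logprobs: PromptLogprobs = [
        None,                      # assumed: first prompt token gets no score
        {464: -1.23, 318: -2.05},  # token id -> logprob at prompt position 1
        {257: -0.51},              # prompt position 2
    ]

    # One dict per generated token, mapping token id -> logprob.
    sample_logprobs: SampleLogprobs = [
        {262: -0.12, 290: -3.40},
        {13: -0.88},
    ]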
@@ -116,7 +119,7 @@ class Sequence:
         self.block_size = block_size
 
         self.data = SequenceData(prompt_token_ids)
-        self.output_logprobs: List[Dict[int, float]] = []
+        self.output_logprobs: SampleLogprobs = []
         self.output_text = ""
 
         self.logical_token_blocks: List[LogicalTokenBlock] = []
@@ -196,7 +199,7 @@ class Sequence:
         """
         if seq_len is None:
             seq_len = self.get_len()
-        # Note: HF implementation does not count the EOS token
+        # NOTE: HF implementation does not count the EOS token
         # towards the length, we align with that here for testing.
         if (eos_token_id is not None
                 and self.get_last_token_id() == eos_token_id):
@@ -238,6 +241,19 @@ class SequenceGroup:
         self.seqs_dict = {seq.seq_id: seq for seq in seqs}
         self.sampling_params = sampling_params
         self.arrival_time = arrival_time
+        self.prompt_logprobs: Optional[PromptLogprobs] = None
+
+    @property
+    def prompt(self) -> str:
+        # All sequences in the group should have the same prompt.
+        # We use the prompt of an arbitrary sequence.
+        return next(iter(self.seqs_dict.values())).prompt
+
+    @property
+    def prompt_token_ids(self) -> List[int]:
+        # All sequences in the group should have the same prompt.
+        # We use the prompt of an arbitrary sequence.
+        return next(iter(self.seqs_dict.values())).data.prompt_token_ids
 
     def get_max_num_running_seqs(self) -> int:
         """The maximum number of sequences running in parallel in the remaining
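A short usage sketch for the new field and properties; `seq_group` is assumed to be an already-constructed SequenceGroup:

    # Hypothetical: `seq_group` is an existing SequenceGroup instance.
    prompt = seq_group.prompt               # shared by every sequence in the group
    token_ids = seq_group.prompt_token_ids  # ids of that shared prompt

    # prompt_logprobs starts as None; the assumption here is that it is
    # filled in later only when the request asked for prompt logprobs.
    if seq_group.prompt_logprobs is not None:
        num_scored = len(seq_group.prompt_logprobs)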
@@ -370,6 +386,22 @@ class SequenceOutputs:
                 and self.logprobs == other.logprobs)
 
 
+class SequenceGroupOutputs:
+    """The model outputs associated with a sequence group."""
+
+    def __init__(
+        self,
+        samples: List[SequenceOutputs],
+        prompt_logprobs: Optional[PromptLogprobs],
+    ) -> None:
+        self.samples = samples
+        self.prompt_logprobs = prompt_logprobs
+
+    def __repr__(self) -> str:
+        return (f"SequenceGroupOutputs(samples={self.samples}, "
+                f"prompt_logprobs={self.prompt_logprobs})")
+
+
 # For each sequence group, we generate a list of SequenceOutputs object,
 # each of which contains one possible candidate for the next token.
-SamplerOutput = List[List[SequenceOutputs]]
+SamplerOutput = List[SequenceGroupOutputs]
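A sketch of how the reshaped SamplerOutput nests. It assumes SequenceOutputs takes a parent sequence id, an output token id, and a logprob dict (its constructor is not shown in this hunk), and all values are invented:

    from vllm.sequence import (SamplerOutput, SequenceGroupOutputs,
                               SequenceOutputs)

    # Assumed SequenceOutputs signature; invented ids and logprobs.
    sample = SequenceOutputs(parent_seq_id=0, output_token=262,
                             logprobs={262: -0.12, 290: -3.40})

    # One entry per scheduled sequence group; prompt_logprobs is presumably
    # non-None only on the step that processes the prompt tokens.
    sampler_output: SamplerOutput = [
        SequenceGroupOutputs(samples=[sample], prompt_logprobs=None),
    ]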