Implement prompt logprobs & Batched topk for computing logprobs (#1328)

Co-authored-by: Yunmo Chen <16273544+wanmok@users.noreply.github.com>
This commit is contained in:
Zhuohan Li
2023-10-16 10:56:50 -07:00
committed by GitHub
parent 928de46888
commit 9d9072a069
14 changed files with 369 additions and 130 deletions

View File

@@ -12,8 +12,8 @@ from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
SequenceGroupMetadata, SequenceOutputs,
SequenceStatus)
SequenceGroupMetadata, SequenceGroupOutputs,
SequenceOutputs, SequenceStatus)
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
get_tokenizer)
from vllm.utils import Counter
@@ -350,9 +350,15 @@ class LLMEngine:
eos_token_id=self.tokenizer.eos_token_id))
return current_worst_score >= highest_attainable_score
def _process_sequence_group_samples(
self, seq_group: SequenceGroup,
samples: List[SequenceOutputs]) -> None:
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
outputs: SequenceGroupOutputs) -> None:
# Process prompt logprobs
prompt_logprobs = outputs.prompt_logprobs
if prompt_logprobs is not None:
seq_group.prompt_logprobs = prompt_logprobs
# Process samples
samples = outputs.samples
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
existing_finished_seqs = seq_group.get_finished_seqs()
parent_child_dict = {
@@ -520,8 +526,8 @@ class LLMEngine:
scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
# Update the scheduled sequence groups with the model outputs.
scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
for seq_group, samples in zip(scheduled_seq_groups, output):
self._process_sequence_group_samples(seq_group, samples)
for seq_group, outputs in zip(scheduled_seq_groups, output):
self._process_sequence_group_outputs(seq_group, outputs)
# Free the finished sequence groups.
self.scheduler.free_finished_seq_groups()