[Core] Refactoring sampler and support prompt logprob for chunked prefill (#4309)

This commit is contained in:
SangBin Cho
2024-04-26 22:02:02 +09:00
committed by GitHub
parent a88081bf76
commit 603ad84815
18 changed files with 859 additions and 630 deletions

View File

@@ -22,7 +22,7 @@ from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
SequenceGroup, SequenceStage)
SequenceGroup, SequenceGroupMetadata)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
get_tokenizer_group)
@@ -476,9 +476,12 @@ class LLMEngine:
return self.scheduler.has_unfinished_seqs()
def _process_model_outputs(
self, output: List[SamplerOutput],
scheduled_seq_groups: List[SequenceGroup],
ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]:
self,
output: List[SamplerOutput],
scheduled_seq_groups: List[SequenceGroup],
ignored_seq_groups: List[SequenceGroup],
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> List[RequestOutput]:
"""Apply the model output to the sequences in the scheduled seq groups.
Returns RequestOutputs that can be returned to the client.
@@ -492,17 +495,15 @@ class LLMEngine:
sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups))
# Update the scheduled sequence groups with the model outputs.
for scheduled_seq_group, outputs in zip(scheduled_seq_groups,
output_by_sequence_group):
for scheduled_seq_group, outputs, seq_group_meta in zip(
scheduled_seq_groups, output_by_sequence_group,
seq_group_metadata_list):
seq_group = scheduled_seq_group.seq_group
seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size)
# If all sequences in the sequence group are in DECODE, then we can
# process the output tokens. Otherwise, they are (chunked) prefill
# samples and should not be processed.
stages = [seq.data._stage for seq in seq_group.seqs_dict.values()]
if all(stage == SequenceStage.DECODE for stage in stages):
self.output_processor.process_prompt_logprob(seq_group, outputs)
if seq_group_meta.do_sample:
self.output_processor.process_outputs(seq_group, outputs)
# Free the finished sequence groups.
@@ -585,7 +586,7 @@ class LLMEngine:
request_outputs = self._process_model_outputs(
output, scheduler_outputs.scheduled_seq_groups,
scheduler_outputs.ignored_seq_groups)
scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
# Log stats.
if self.log_stats: