[Core] Refactoring sampler and support prompt logprob for chunked prefill (#4309)
@@ -22,7 +22,7 @@ from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
-                           SequenceGroup, SequenceStage)
+                           SequenceGroup, SequenceGroupMetadata)
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                      get_tokenizer_group)
@@ -476,9 +476,12 @@ class LLMEngine:
         return self.scheduler.has_unfinished_seqs()

     def _process_model_outputs(
-            self, output: List[SamplerOutput],
-            scheduled_seq_groups: List[SequenceGroup],
-            ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]:
+            self,
+            output: List[SamplerOutput],
+            scheduled_seq_groups: List[SequenceGroup],
+            ignored_seq_groups: List[SequenceGroup],
+            seq_group_metadata_list: List[SequenceGroupMetadata],
+    ) -> List[RequestOutput]:
         """Apply the model output to the sequences in the scheduled seq groups.

         Returns RequestOutputs that can be returned to the client.
@@ -492,17 +495,15 @@
             sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups))

         # Update the scheduled sequence groups with the model outputs.
-        for scheduled_seq_group, outputs in zip(scheduled_seq_groups,
-                                                output_by_sequence_group):
+        for scheduled_seq_group, outputs, seq_group_meta in zip(
+                scheduled_seq_groups, output_by_sequence_group,
+                seq_group_metadata_list):
             seq_group = scheduled_seq_group.seq_group
             seq_group.update_num_computed_tokens(
                 scheduled_seq_group.token_chunk_size)

-            # If all sequences in the sequence group are in DECODE, then we can
-            # process the output tokens. Otherwise, they are (chunked) prefill
-            # samples and should not be processed.
-            stages = [seq.data._stage for seq in seq_group.seqs_dict.values()]
-            if all(stage == SequenceStage.DECODE for stage in stages):
+            self.output_processor.process_prompt_logprob(seq_group, outputs)
+            if seq_group_meta.do_sample:
                 self.output_processor.process_outputs(seq_group, outputs)

         # Free the finished sequence groups.
@@ -585,7 +586,7 @@ class LLMEngine:

         request_outputs = self._process_model_outputs(
             output, scheduler_outputs.scheduled_seq_groups,
-            scheduler_outputs.ignored_seq_groups)
+            scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)

         # Log stats.
         if self.log_stats:
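A rough way to read the new control flow: prompt logprobs are now recorded for every prefill chunk, while sampled tokens are only consumed on steps where the scheduler actually asked for sampling, which is what `seq_group_meta.do_sample` signals. The sketch below is a minimal stand-alone illustration of that split; the `Fake*` classes and the `process_model_outputs` helper are stand-ins invented for the example, not vLLM's real types, and the assumption that `do_sample` is False for intermediate prefill chunks is inferred from the stage check that this commit removes.

# Minimal sketch (hypothetical types, not vLLM's classes): prompt logprobs
# accumulate on every chunk; sampled tokens are appended only when the step
# actually sampled (assumed: last prefill chunk and decode steps).
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class FakeSeqGroupMetadata:          # stand-in for SequenceGroupMetadata
    do_sample: bool


@dataclass
class FakeOutput:                    # stand-in for one group's sampler output
    prompt_logprobs: Optional[List[float]]
    sampled_token: Optional[int]


@dataclass
class FakeSeqGroup:                  # stand-in for SequenceGroup
    prompt_logprobs: List[float] = field(default_factory=list)
    output_tokens: List[int] = field(default_factory=list)


def process_model_outputs(groups, outputs, metadata_list):
    for group, out, meta in zip(groups, outputs, metadata_list):
        # Prompt logprobs are appended for every chunk of the prefill.
        if out.prompt_logprobs is not None:
            group.prompt_logprobs.extend(out.prompt_logprobs)
        # Sampled tokens are consumed only when this step sampled.
        if meta.do_sample:
            group.output_tokens.append(out.sampled_token)


# Usage: a two-chunk prefill followed by one decode step for one group.
group = FakeSeqGroup()
steps = [
    (FakeOutput([-0.1, -0.2], None), FakeSeqGroupMetadata(False)),  # first chunk
    (FakeOutput([-0.3], 7), FakeSeqGroupMetadata(True)),            # last chunk
    (FakeOutput(None, 9), FakeSeqGroupMetadata(True)),              # decode
]
for out, meta in steps:
    process_model_outputs([group], [out], [meta])
print(group.prompt_logprobs)  # [-0.1, -0.2, -0.3]
print(group.output_tokens)    # [7, 9]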