[Feature] [Spec decode]: Enable MLPSpeculator/Medusa and prompt_logprobs with ChunkedPrefill (#10132)
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: wallashss <wallashss@ibm.com>
Co-authored-by: wallashss <wallashss@ibm.com>
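For context, a minimal, hypothetical usage sketch of the combination this change enables: speculative decoding with an MLPSpeculator-style draft model, chunked prefill, and prompt_logprobs requested together. The model names and argument values below are illustrative assumptions, not taken from this commit.

# Hypothetical example; model names and values are assumptions, not from this commit.
from vllm import LLM, SamplingParams

llm = LLM(
    model="ibm-granite/granite-3b-code-instruct",  # assumed target model
    # Assumed MLPSpeculator-style draft model; a Medusa checkpoint would be configured similarly.
    speculative_model="ibm-granite/granite-3b-code-instruct-accelerator",
    enable_chunked_prefill=True,  # previously not usable together with the options below
)

# prompt_logprobs alongside chunked prefill + speculative decoding is what this change targets.
params = SamplingParams(max_tokens=64, prompt_logprobs=1)
out = llm.generate(["The capital of France is"], params)[0]
print(out.prompt_logprobs)
print(out.outputs[0].text)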
@@ -1010,8 +1010,23 @@ class LLMEngine:
     self.speculative_config
 # Organize outputs by [step][sequence group] instead of
 # [sequence group][step].
-outputs_by_sequence_group = create_output_by_sequence_group(
-    outputs, num_seq_groups=len(seq_group_metadata_list))
+if self.scheduler_config.is_multi_step:
+    outputs_by_sequence_group = create_output_by_sequence_group(
+        outputs, len(seq_group_metadata_list))
+elif self.speculative_config:
+    # Decodes are multi-steps while prefills are not, outputting at
+    # most 1 token. Separate them so that we can trigger chunk
+    # processing without having to pad or copy over prompts K times
+    # to match decodes structure (costly with prompt_logprobs).
+    num_prefills = sum(sg.is_prompt
+                       for sg in seq_group_metadata_list)
+    prefills, decodes = outputs[:num_prefills], outputs[
+        num_prefills:]
+    outputs_by_sequence_group = create_output_by_sequence_group(
+        decodes,
+        num_seq_groups=len(seq_group_metadata_list) - num_prefills)
+    outputs_by_sequence_group = [p.outputs for p in prefills
+                                 ] + outputs_by_sequence_group
 # We have outputs for multiple steps submitted in a single burst,
 # so invalidate is_first_step_output.
 is_first_step_output = None
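As a rough illustration of the regrouping performed in the hunk above, the toy sketch below (hypothetical names, not vLLM's real data structures) groups the multi-step decode outputs per sequence group and prepends the single-step prefill outputs, avoiding the K-fold padding of prompts mentioned in the comment.

# Toy sketch only: stands in for the SamplerOutput/SequenceGroupOutput handling in the diff.
from typing import List


def regroup(prefill_outputs: List[str],
            decode_steps: List[List[str]]) -> List[List[str]]:
    """Return outputs indexed by [sequence_group][step].

    prefill_outputs: one entry per prefill sequence group (single step each).
    decode_steps:    decode_steps[k][g] is step k's output for decode group g.
    """
    num_decode_groups = len(decode_steps[0]) if decode_steps else 0
    # Transpose [step][group] -> [group][step] for the multi-step decodes.
    decode_by_group = [[step[g] for step in decode_steps]
                       for g in range(num_decode_groups)]
    # Each prefill contributes at most one step, so it is prepended as-is
    # instead of being padded or copied to the number of speculative steps.
    return [[p] for p in prefill_outputs] + decode_by_group


if __name__ == "__main__":
    prefills = ["p0_tok", "p1_tok"]                 # 2 prefill groups
    decodes = [["d0_s0", "d1_s0"],                  # step 0
               ["d0_s1", "d1_s1"],                  # step 1
               ["d0_s2", "d1_s2"]]                  # step 2
    print(regroup(prefills, decodes))
    # [['p0_tok'], ['p1_tok'], ['d0_s0', 'd0_s1', 'd0_s2'], ['d1_s0', 'd1_s1', 'd1_s2']]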