[Bugfix] Fix incorrect updates to num_computed_tokens in multi-step scheduling (#9038)

Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
This commit is contained in:
Varun Sundar Rabindranath
2024-10-06 15:48:11 -04:00
committed by GitHub
parent fdf59d30ea
commit cb3b2b9ba4
6 changed files with 179 additions and 110 deletions

@@ -962,6 +962,45 @@ class LLMEngine:
return
def _update_num_computed_tokens_for_multi_step_prefill(
self, seq_group: SequenceGroup,
seq_group_meta: SequenceGroupMetadata,
is_first_step_output: Optional[bool]):
"""
This function updates num_computed_tokens for prompt sequences
when Multi-Step is enabled.
seq_group: SequenceGroup to update the num_computed_tokens for.
seq_group_meta: Metadata of the given SequenceGroup.
is_first_step_output: Optional[bool] -
When available, is_first_step_output indicates if the appended
output token is the output of the first-step in multi-step.
A value of None indicates that outputs from all steps in
multi-step are submitted in a single burst.
"""
assert self.scheduler_config.is_multi_step
if not seq_group_meta.is_prompt:
# num_computed_tokens updates for multi-step decodes happen after
# the tokens are appended to the sequence.
return
do_update: bool = False
if self.scheduler_config.chunked_prefill_enabled:
# In multi-step + chunked-prefill case, the prompt sequences
# that are scheduled are fully processed in the first step.
do_update = is_first_step_output is None or is_first_step_output
else:
# Normal multi-step decoding case. Prompt sequences are
# actually single-stepped, so always update in this case.
assert seq_group.state.num_steps == 1
do_update = True
if do_update:
seq_group.update_num_computed_tokens(
seq_group_meta.token_chunk_size)
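
For readers tracing the new helper, the decision it makes reduces to a small pure function. The sketch below is illustrative only; the plain booleans stand in for SequenceGroupMetadata / SchedulerConfig fields and are not vLLM APIs:

from typing import Optional

def should_update_prefill(is_prompt: bool,
                          chunked_prefill_enabled: bool,
                          is_first_step_output: Optional[bool]) -> bool:
    # Decode sequences are accounted for later, when the sampled token
    # is appended to the sequence.
    if not is_prompt:
        return False
    if chunked_prefill_enabled:
        # Multi-step + chunked-prefill: the scheduled prompt chunk is
        # fully processed in the first step. None means all step outputs
        # arrived in one burst, which includes the first step.
        return is_first_step_output is None or bool(is_first_step_output)
    # Plain multi-step: prompt sequences are single-stepped, always update.
    return True

# e.g. a later step of a chunked prompt is already a decode, so no update:
assert should_update_prefill(True, True, False) is False
assert should_update_prefill(True, True, None) is True
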
def _process_model_outputs(self,
ctx: SchedulerContext,
request_id: Optional[str] = None) -> None:
@@ -972,64 +1011,6 @@ class LLMEngine:
request_id: If provided, then only this request is going to be processed
"""
def update_prefill_num_computed_tokens(
seq_group: SequenceGroup,
seq_group_meta: SequenceGroupMetadata, num_outputs: int,
is_first_step_output: Optional[bool]) -> None:
"""
When multi-step and chunked-prefill are enabled together, the
prefill sequences scheduled for multi-step execution turn into
decodes in the first step itself. This function accounts
for that conversion.
seq_group: SequenceGroup - A prefill seq_group
seq_group_meta: SequenceGroupMetadata - Metadata of the given
prefill seq_group
num_outputs: int - number of output tokens being processed for the
given seq_group
is_first_step_output: Optional[bool] -
If multi-step is enabled and num_outputs is 1, this value
indicates if this output belongs to the first step in the
multi-step.
If multi-step is enabled and num_outputs > 1, this value
must be None, as num_outputs > 1 indicates that outputs from
all the steps in multi-step are submitted in a single burst.
When multi-step is disabled, this value is always True.
"""
assert seq_group_meta.is_prompt
token_chunk_size = seq_group_meta.token_chunk_size
if num_outputs == 1:
assert is_first_step_output is not None
if seq_group_meta.state.num_steps == 1:
assert is_first_step_output is True
seq_group.update_num_computed_tokens(token_chunk_size)
return
# multi-step prefill is only supported when multi-step is
# enabled with chunked prefill
assert self.scheduler_config.is_multi_step and \
self.scheduler_config.chunked_prefill_enabled
if is_first_step_output is True:
# This sequence is a prompt during the first step only.
seq_group.update_num_computed_tokens(token_chunk_size)
return
assert is_first_step_output is None
# multi-step prefill is only supported when multi-step is
# enabled with chunked prefill. Outputs from all the steps are
# submitted in a single burst.
assert self.scheduler_config.is_multi_step and \
self.scheduler_config.chunked_prefill_enabled
assert num_outputs == seq_group_meta.state.num_steps, \
f"#outputs {len(outputs)} - num steps {seq_group_meta.state.num_steps}" #noqa
# This sequence is a prompt during the first step only.
seq_group.update_num_computed_tokens(token_chunk_size)
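
As an illustration of the accounting that the removed helper (and its replacement) aim for, here is a hedged walk-through with made-up numbers; the variables below are illustrative, not engine state:

prompt_len = 8    # tokens in the scheduled prompt chunk (hypothetical)
num_steps = 4     # multi-step lookahead (hypothetical)
num_computed = 0

# Step 1: the whole prompt chunk is processed, so it is counted once.
num_computed += prompt_len               # 8
# Steps 2..4: the sequence now behaves as a decode; each appended token
# advances the counter by exactly one.
for _ in range(num_steps - 1):
    num_computed += 1                    # 9, 10, 11

assert num_computed == prompt_len + (num_steps - 1)
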
now = time.time()
if len(ctx.output_queue) == 0:
@@ -1090,7 +1071,7 @@ class LLMEngine:
seq_group_meta = seq_group_metadata_list[i]
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
seq_group = scheduled_seq_group.seq_group
seq_group: SequenceGroup = scheduled_seq_group.seq_group
if seq_group.is_finished():
finished_before.append(i)
@@ -1101,14 +1082,14 @@ class LLMEngine:
else:
output = [outputs_by_sequence_group[0][i]]
if not is_async and seq_group_meta.is_prompt:
# Updates for all decodes happen when we actually append the
# token ids to the seq in process_outputs.
update_prefill_num_computed_tokens(seq_group, seq_group_meta,
len(output),
is_first_step_output)
elif not is_async:
seq_group.update_num_computed_tokens(1)
if not is_async:
if self.scheduler_config.is_multi_step:
# Updates happen only if the sequence group is a prefill
self._update_num_computed_tokens_for_multi_step_prefill(
seq_group, seq_group_meta, is_first_step_output)
else:
seq_group.update_num_computed_tokens(
seq_group_meta.token_chunk_size)
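
A small note on why the non-multi-step branch can use token_chunk_size unconditionally: assuming, as in the scheduler, that a decode step schedules exactly one token, token_chunk_size already degenerates to 1 for decodes, so a single call covers prompts and decodes alike. A toy sketch (names are illustrative, not vLLM attributes):

def scheduled_chunk_size(is_prompt: bool, prompt_chunk: int) -> int:
    # Illustrative stand-in for SequenceGroupMetadata.token_chunk_size.
    return prompt_chunk if is_prompt else 1

assert scheduled_chunk_size(True, 512) == 512   # prefill chunk
assert scheduled_chunk_size(False, 512) == 1    # decode step
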
if outputs:
for o in outputs:
@@ -1132,16 +1113,8 @@ class LLMEngine:
else:
self.output_processor.process_prompt_logprob(seq_group, output)
if seq_group_meta.do_sample:
output_token_num = self.output_processor.process_outputs(
self.output_processor.process_outputs(
seq_group, output, is_async)
if self.speculative_config:
# We subtract 1 here because num_computed_tokens is
# already advanced by one in the decoding phase
# (even without speculative decoding), so that
# token must not be counted twice.
seq_group.update_num_computed_tokens(output_token_num -
1)
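
For context on the branch removed above, the adjustment it made can be reproduced with toy numbers: the decode path has already advanced the counter by one, so only the remainder of the speculative step's tokens is added. Illustrative values only:

accepted_tokens = 3    # tokens emitted by one speculative step (hypothetical)
num_computed = 100     # counter before processing this step's outputs

num_computed += 1                       # decode phase already adds one
num_computed += accepted_tokens - 1     # the removed "-1" adjustment
assert num_computed == 100 + accepted_tokens
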
if seq_group.is_finished():
finished_now.append(i)
@@ -1250,20 +1223,15 @@ class LLMEngine:
if seq_group.is_finished():
continue
if seq_group_metadata.is_prompt:
if self.scheduler_config.is_multi_step and \
self.scheduler_config.chunked_prefill_enabled:
# Prompts are scheduled in multi-step only when
# chunking is enabled. These prompts turn into
# decodes after the very first step. Therefore,
# we skip the update to the num_computed_tokens
# here.
seq_group.update_num_computed_tokens(1)
else:
seq_group.update_num_computed_tokens(
seq_group_metadata.token_chunk_size)
if self.scheduler_config.is_multi_step:
# Updates happen only if the sequence group is a prefill
self._update_num_computed_tokens_for_multi_step_prefill(
seq_group, seq_group_metadata,
seq_group.state.num_steps == 1)
else:
seq_group.update_num_computed_tokens(1)
seq_group.update_num_computed_tokens(
seq_group_metadata.token_chunk_size)
if seq_group_metadata.do_sample:
assert len(sequence_group_outputs.samples) == 1, (
"Async output processor expects a single sample"
@@ -1273,7 +1241,15 @@ class LLMEngine:
assert len(seq_group.seqs) == 1
seq = seq_group.seqs[0]
seq.append_token_id(sample.output_token, sample.logprobs)
if self.scheduler_config.is_multi_step:
is_prefill_append = seq.data.get_num_uncomputed_tokens() == 0
seq.append_token_id(sample.output_token, sample.logprobs)
if not is_prefill_append:
seq_group.update_num_computed_tokens(1)
else:
seq.append_token_id(sample.output_token, sample.logprobs)
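
To make the is_prefill_append bookkeeping above concrete, here is a self-contained sketch with a hypothetical sequence stand-in (not the vLLM Sequence class): the token appended right after prefill is not counted again, while every later decode append advances the counter by one.

class _ToySeq:
    def __init__(self, prompt_len: int):
        self.tokens = list(range(prompt_len))
        self.num_computed = prompt_len      # prefill already accounted for

    def uncomputed(self) -> int:
        return len(self.tokens) - self.num_computed

    def append(self, tok: int) -> None:
        is_prefill_append = self.uncomputed() == 0
        self.tokens.append(tok)
        if not is_prefill_append:
            self.num_computed += 1

s = _ToySeq(prompt_len=8)
s.append(100)                 # first sampled token: counter stays at 8
s.append(101)
s.append(102)                 # decode appends: 9, then 10
assert s.num_computed == 10
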
def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results.