[Bugfix] Fix incorrect updates to num_computed_tokens in multi-step scheduling (#9038)
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
This commit is contained in:
committed by
GitHub
parent
fdf59d30ea
commit
cb3b2b9ba4
@@ -962,6 +962,45 @@ class LLMEngine:
|
||||
|
||||
return
|
||||
|
||||
def _update_num_computed_tokens_for_multi_step_prefill(
        self, seq_group: SequenceGroup,
        seq_group_meta: SequenceGroupMetadata,
        is_first_step_output: Optional[bool]):
    """
    Advance num_computed_tokens for a prompt sequence group when
    Multi-Step is enabled.

    Sequence groups that are not prompts are left untouched by this
    function.

    seq_group: SequenceGroup to update the num_computed_tokens for.
    seq_group_meta: Metadata of the given SequenceGroup. Its is_prompt
        flag gates the update and its token_chunk_size is the amount
        added.
    is_first_step_output: Optional[bool] -
        When available, indicates whether the appended output token is
        the output of the first step in multi-step.
        A value of None indicates that the outputs from all steps in
        multi-step are submitted in a single burst.
    """
    scheduler_config = self.scheduler_config
    assert scheduler_config.is_multi_step

    if not seq_group_meta.is_prompt:
        # Multi-step decodes get their num_computed_tokens updated after
        # the tokens are appended to the sequence, not here.
        return

    if scheduler_config.chunked_prefill_enabled:
        # Multi-step + chunked-prefill: the scheduled prompt sequences are
        # fully processed in the first step, so only a first-step output
        # (or a single burst, signalled by None) triggers the update.
        if not (is_first_step_output is None or is_first_step_output):
            return
    else:
        # Normal multi-step decoding: prompt sequences are actually
        # single-stepped, so the update always applies.
        assert seq_group.state.num_steps == 1

    seq_group.update_num_computed_tokens(seq_group_meta.token_chunk_size)
|
||||
|
||||
def _process_model_outputs(self,
|
||||
ctx: SchedulerContext,
|
||||
request_id: Optional[str] = None) -> None:
|
||||
@@ -972,64 +1011,6 @@ class LLMEngine:
|
||||
request_id: If provided, then only this request is going to be processed
|
||||
"""
|
||||
|
||||
def update_prefill_num_computed_tokens(
|
||||
seq_group: SequenceGroup,
|
||||
seq_group_meta: SequenceGroupMetadata, num_outputs: int,
|
||||
is_first_step_output: Optional[bool]) -> None:
|
||||
"""
|
||||
When multi-step and chunked-prefill are enabled together, the
|
||||
prefill sequence scheduled for multi-step execution turn into
|
||||
decodes in the first step itself. This function accounts
|
||||
for that conversion.
|
||||
|
||||
seq_group: SequenceGroup - A prefill seq_group
|
||||
seq_group_meta: SequenceGroupMetadata - Metadata of the given
|
||||
prefill seq_group
|
||||
num_outputs: int - number of output tokens being processed for the
|
||||
given seq_group
|
||||
is_first_step_output: Optional[bool] -
|
||||
If multi-step is enabled and num_outputs is 1, this value
|
||||
indicates if this outputs belongs to the first step in the
|
||||
multi-step.
|
||||
If multi-step is enabled and num_outputs > 1, this value
|
||||
must be None, as num_outputs > 1 indicates that outputs from
|
||||
all the steps in multi-step are submitted in a single burst.
|
||||
When multi-step is disabled, this value is always True.
|
||||
"""
|
||||
|
||||
assert seq_group_meta.is_prompt
|
||||
|
||||
token_chunk_size = seq_group_meta.token_chunk_size
|
||||
|
||||
if num_outputs == 1:
|
||||
assert is_first_step_output is not None
|
||||
|
||||
if seq_group_meta.state.num_steps == 1:
|
||||
assert is_first_step_output is True
|
||||
seq_group.update_num_computed_tokens(token_chunk_size)
|
||||
return
|
||||
|
||||
# multi-step prefill is only supported when multi-step is
|
||||
# enabled with chunked prefill
|
||||
assert self.scheduler_config.is_multi_step and \
|
||||
self.scheduler_config.chunked_prefill_enabled
|
||||
if is_first_step_output is True:
|
||||
# This sequence is a prompt during the first step only.
|
||||
seq_group.update_num_computed_tokens(token_chunk_size)
|
||||
return
|
||||
|
||||
assert is_first_step_output is None
|
||||
|
||||
# multi-step prefill is only supported when multi-step is
|
||||
# enabled with chunked prefill. Outputs from all the steps are
|
||||
# submitted in a single burst.
|
||||
assert self.scheduler_config.is_multi_step and \
|
||||
self.scheduler_config.chunked_prefill_enabled
|
||||
assert num_outputs == seq_group_meta.state.num_steps, \
|
||||
f"#outputs {len(outputs)} - num steps {seq_group_meta.state.num_steps}" #noqa
|
||||
# This sequence is a prompt during the first step only.
|
||||
seq_group.update_num_computed_tokens(token_chunk_size)
|
||||
|
||||
now = time.time()
|
||||
|
||||
if len(ctx.output_queue) == 0:
|
||||
@@ -1090,7 +1071,7 @@ class LLMEngine:
|
||||
seq_group_meta = seq_group_metadata_list[i]
|
||||
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
|
||||
|
||||
seq_group = scheduled_seq_group.seq_group
|
||||
seq_group: SequenceGroup = scheduled_seq_group.seq_group
|
||||
|
||||
if seq_group.is_finished():
|
||||
finished_before.append(i)
|
||||
@@ -1101,14 +1082,14 @@ class LLMEngine:
|
||||
else:
|
||||
output = [outputs_by_sequence_group[0][i]]
|
||||
|
||||
if not is_async and seq_group_meta.is_prompt:
|
||||
# Updates for all decodes happen when we actually append the
|
||||
# token ids to the seq in process_outputs.
|
||||
update_prefill_num_computed_tokens(seq_group, seq_group_meta,
|
||||
len(output),
|
||||
is_first_step_output)
|
||||
elif not is_async:
|
||||
seq_group.update_num_computed_tokens(1)
|
||||
if not is_async:
|
||||
if self.scheduler_config.is_multi_step:
|
||||
# Updates happen only if the sequence is prefill
|
||||
self._update_num_computed_tokens_for_multi_step_prefill(
|
||||
seq_group, seq_group_meta, is_first_step_output)
|
||||
else:
|
||||
seq_group.update_num_computed_tokens(
|
||||
seq_group_meta.token_chunk_size)
|
||||
|
||||
if outputs:
|
||||
for o in outputs:
|
||||
@@ -1132,16 +1113,8 @@ class LLMEngine:
|
||||
else:
|
||||
self.output_processor.process_prompt_logprob(seq_group, output)
|
||||
if seq_group_meta.do_sample:
|
||||
output_token_num = self.output_processor.process_outputs(
|
||||
self.output_processor.process_outputs(
|
||||
seq_group, output, is_async)
|
||||
if self.speculative_config:
|
||||
# We -1 here because we always
|
||||
# (w/o speculative decoding) add the number of
|
||||
# computed tokens by one in the decoding phase.
|
||||
# Therefore, we remove that one token that
|
||||
# is already added.
|
||||
seq_group.update_num_computed_tokens(output_token_num -
|
||||
1)
|
||||
|
||||
if seq_group.is_finished():
|
||||
finished_now.append(i)
|
||||
@@ -1250,20 +1223,15 @@ class LLMEngine:
|
||||
if seq_group.is_finished():
|
||||
continue
|
||||
|
||||
if seq_group_metadata.is_prompt:
|
||||
if self.scheduler_config.is_multi_step and \
|
||||
self.scheduler_config.chunked_prefill_enabled:
|
||||
# Prompts are scheduled in multi-step only when
|
||||
# chunking is enabled. These prompts turn into
|
||||
# decodes after the very first step. Therefore,
|
||||
# we skip the update to the num_computed_tokens
|
||||
# here.
|
||||
seq_group.update_num_computed_tokens(1)
|
||||
else:
|
||||
seq_group.update_num_computed_tokens(
|
||||
seq_group_metadata.token_chunk_size)
|
||||
if self.scheduler_config.is_multi_step:
|
||||
# Updates happen only if the sequence is prefill
|
||||
self._update_num_computed_tokens_for_multi_step_prefill(
|
||||
seq_group, seq_group_metadata,
|
||||
seq_group.state.num_steps == 1)
|
||||
else:
|
||||
seq_group.update_num_computed_tokens(1)
|
||||
seq_group.update_num_computed_tokens(
|
||||
seq_group_metadata.token_chunk_size)
|
||||
|
||||
if seq_group_metadata.do_sample:
|
||||
assert len(sequence_group_outputs.samples) == 1, (
|
||||
"Async output processor expects a single sample"
|
||||
@@ -1273,7 +1241,15 @@ class LLMEngine:
|
||||
|
||||
assert len(seq_group.seqs) == 1
|
||||
seq = seq_group.seqs[0]
|
||||
seq.append_token_id(sample.output_token, sample.logprobs)
|
||||
|
||||
if self.scheduler_config.is_multi_step:
|
||||
is_prefill_append = seq.data.get_num_uncomputed_tokens(
|
||||
) == 0
|
||||
seq.append_token_id(sample.output_token, sample.logprobs)
|
||||
if not is_prefill_append:
|
||||
seq_group.update_num_computed_tokens(1)
|
||||
else:
|
||||
seq.append_token_id(sample.output_token, sample.logprobs)
|
||||
|
||||
def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
|
||||
"""Performs one decoding iteration and returns newly generated results.
|
||||
|
||||
Reference in New Issue
Block a user