[V0 Deprecation] Remove multi-step scheduling (#22138)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
This commit is contained in:
@@ -25,7 +25,6 @@ from vllm.engine.metrics_types import StatLoggerBase, Stats
|
||||
from vllm.engine.output_processor.interfaces import (
|
||||
SequenceGroupOutputProcessor)
|
||||
from vllm.engine.output_processor.stop_checker import StopChecker
|
||||
from vllm.engine.output_processor.util import create_output_by_sequence_group
|
||||
from vllm.entrypoints.openai.logits_processors import (
|
||||
get_logits_processors as get_openai_logits_processors)
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
@@ -91,7 +90,7 @@ class OutputData(NamedTuple):
|
||||
|
||||
class SchedulerContext:
|
||||
|
||||
def __init__(self, multi_step_stream_outputs: bool = False):
|
||||
def __init__(self) -> None:
|
||||
self.output_queue: Deque[OutputData] = deque()
|
||||
self.request_outputs: List[Union[RequestOutput,
|
||||
PoolingRequestOutput]] = []
|
||||
@@ -99,8 +98,6 @@ class SchedulerContext:
|
||||
List[SequenceGroupMetadata]] = None
|
||||
self.scheduler_outputs: Optional[SchedulerOutputs] = None
|
||||
|
||||
self.multi_step_stream_outputs: bool = multi_step_stream_outputs
|
||||
|
||||
def append_output(self, outputs: List[SamplerOutput],
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
scheduler_outputs: SchedulerOutputs, is_async: bool,
|
||||
@@ -303,8 +300,7 @@ class LLMEngine:
|
||||
]
|
||||
|
||||
self.scheduler_contexts = [
|
||||
SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
|
||||
multi_step_stream_outputs)
|
||||
SchedulerContext()
|
||||
for _ in range(self.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
|
||||
@@ -683,8 +679,7 @@ class LLMEngine:
|
||||
"Priority scheduling is not enabled.")
|
||||
|
||||
if isinstance(params, SamplingParams) \
|
||||
and params.logits_processors \
|
||||
and self.scheduler_config.num_scheduler_steps > 1:
|
||||
and params.logits_processors:
|
||||
raise ValueError(
|
||||
"Logits processors are not supported in multi-step decoding")
|
||||
|
||||
@@ -868,45 +863,6 @@ class LLMEngine:
|
||||
|
||||
return
|
||||
|
||||
def _update_num_computed_tokens_for_multi_step_prefill(
|
||||
self, seq_group: SequenceGroup,
|
||||
seq_group_meta: SequenceGroupMetadata,
|
||||
is_first_step_output: Optional[bool]):
|
||||
"""
|
||||
This function updates num_computed_tokens for prompt sequences
|
||||
when Multi-Step is enabled.
|
||||
|
||||
seq_group: SequenceGroup to update the num_computed_tokens for.
|
||||
seq_group_meta: Metadata of the given SequenceGroup.
|
||||
is_first_step_output: Optional[bool] -
|
||||
When available, is_first_step_output indicates if the appended
|
||||
output token is the output of the first-step in multi-step.
|
||||
A value of None indicates that outputs from all steps in
|
||||
in multi-step are submitted in a single burst.
|
||||
"""
|
||||
|
||||
assert self.scheduler_config.is_multi_step
|
||||
|
||||
if not seq_group_meta.is_prompt:
|
||||
# num_computed_token updates for multi-step decodes happen after
|
||||
# the tokens are appended to the sequence.
|
||||
return
|
||||
|
||||
do_update: bool = False
|
||||
if self.scheduler_config.chunked_prefill_enabled:
|
||||
# In multi-step + chunked-prefill case, the prompt sequences
|
||||
# that are scheduled are fully processed in the first step.
|
||||
do_update = is_first_step_output is None or is_first_step_output
|
||||
else:
|
||||
# Normal multi-step decoding case. In this case prompt-sequences
|
||||
# are actually single-stepped. Always update in this case.
|
||||
assert seq_group.state.num_steps == 1
|
||||
do_update = True
|
||||
|
||||
if do_update:
|
||||
seq_group.update_num_computed_tokens(
|
||||
seq_group_meta.token_chunk_size)
|
||||
|
||||
def _process_model_outputs(self,
|
||||
ctx: SchedulerContext,
|
||||
request_id: Optional[str] = None) -> None:
|
||||
@@ -939,33 +895,8 @@ class LLMEngine:
|
||||
|
||||
has_multiple_outputs: bool = len(outputs) > 1
|
||||
outputs_by_sequence_group: List[List[SequenceGroupOutput]]
|
||||
if has_multiple_outputs:
|
||||
assert self.scheduler_config.is_multi_step or \
|
||||
self.speculative_config
|
||||
# Organize outputs by [step][sequence group] instead of
|
||||
# [sequence group][step].
|
||||
if self.scheduler_config.is_multi_step:
|
||||
outputs_by_sequence_group = create_output_by_sequence_group(
|
||||
outputs, len(seq_group_metadata_list))
|
||||
elif self.speculative_config:
|
||||
# Decodes are multi-steps while prefills are not, outputting at
|
||||
# most 1 token. Separate them so that we can trigger chunk
|
||||
# processing without having to pad or copy over prompts K times
|
||||
# to match decodes structure (costly with prompt_logprobs).
|
||||
num_prefills = sum(sg.is_prompt
|
||||
for sg in seq_group_metadata_list)
|
||||
prefills, decodes = outputs[:num_prefills], outputs[
|
||||
num_prefills:]
|
||||
outputs_by_sequence_group = create_output_by_sequence_group(
|
||||
decodes,
|
||||
num_seq_groups=len(seq_group_metadata_list) - num_prefills)
|
||||
outputs_by_sequence_group = [p.outputs for p in prefills
|
||||
] + outputs_by_sequence_group
|
||||
# We have outputs for multiple steps submitted in a single burst,
|
||||
# so invalidate is_first_step_output.
|
||||
is_first_step_output = None
|
||||
else:
|
||||
outputs_by_sequence_group = outputs
|
||||
assert not has_multiple_outputs
|
||||
outputs_by_sequence_group = outputs
|
||||
|
||||
# Determine the requests we need to operate on
|
||||
if request_id:
|
||||
@@ -1006,13 +937,8 @@ class LLMEngine:
|
||||
output = [outputs_by_sequence_group[0][i]]
|
||||
|
||||
if not is_async:
|
||||
if self.scheduler_config.is_multi_step:
|
||||
# Updates happen only if the sequence is prefill
|
||||
self._update_num_computed_tokens_for_multi_step_prefill(
|
||||
seq_group, seq_group_meta, is_first_step_output)
|
||||
else:
|
||||
seq_group.update_num_computed_tokens(
|
||||
seq_group_meta.token_chunk_size or 0)
|
||||
seq_group.update_num_computed_tokens(
|
||||
seq_group_meta.token_chunk_size or 0)
|
||||
|
||||
if outputs:
|
||||
for o in outputs:
|
||||
@@ -1074,15 +1000,6 @@ class LLMEngine:
|
||||
for scheduler in self.scheduler:
|
||||
scheduler.free_finished_seq_groups()
|
||||
|
||||
# For multi-step without streaming, don't create outputs each iteration
|
||||
if not is_last_step and not ctx.multi_step_stream_outputs:
|
||||
# Immediately process request outputs here (if callback is given)
|
||||
if (finished_now
|
||||
and self.process_request_outputs_callback is not None):
|
||||
self.process_request_outputs_callback(ctx.request_outputs)
|
||||
ctx.request_outputs.clear()
|
||||
return
|
||||
|
||||
# Create the outputs
|
||||
for i in indices:
|
||||
if i in skip or i in finished_before or i in finished_now:
|
||||
@@ -1101,13 +1018,7 @@ class LLMEngine:
|
||||
if request_output:
|
||||
ctx.request_outputs.append(request_output)
|
||||
|
||||
# For multi-step with streaming, create outputs each iteration
|
||||
if not is_last_step and ctx.multi_step_stream_outputs:
|
||||
# Immediately process request outputs here (if callback is given)
|
||||
if self.process_request_outputs_callback is not None:
|
||||
self.process_request_outputs_callback(ctx.request_outputs)
|
||||
ctx.request_outputs.clear()
|
||||
return
|
||||
# Create outputs only after processing the scheduler's results
|
||||
|
||||
for seq_group in scheduler_outputs.ignored_seq_groups:
|
||||
params = seq_group.sampling_params
|
||||
@@ -1157,16 +1068,10 @@ class LLMEngine:
|
||||
if seq_group.is_finished():
|
||||
continue
|
||||
|
||||
if self.scheduler_config.is_multi_step:
|
||||
# Updates happen only if the sequence is prefill
|
||||
self._update_num_computed_tokens_for_multi_step_prefill(
|
||||
seq_group, seq_group_metadata,
|
||||
seq_group.state.num_steps == 1)
|
||||
else:
|
||||
token_chunk_size = (seq_group_metadata.token_chunk_size
|
||||
if seq_group_metadata.token_chunk_size
|
||||
is not None else 0)
|
||||
seq_group.update_num_computed_tokens(token_chunk_size)
|
||||
token_chunk_size = (seq_group_metadata.token_chunk_size
|
||||
if seq_group_metadata.token_chunk_size
|
||||
is not None else 0)
|
||||
seq_group.update_num_computed_tokens(token_chunk_size)
|
||||
|
||||
if seq_group_metadata.do_sample:
|
||||
assert len(sequence_group_outputs.samples) == 1, (
|
||||
@@ -1177,16 +1082,8 @@ class LLMEngine:
|
||||
assert len(seq_group.seqs) == 1
|
||||
seq = seq_group.seqs[0]
|
||||
|
||||
if self.scheduler_config.is_multi_step:
|
||||
is_prefill_append = seq.data.get_num_uncomputed_tokens(
|
||||
) == 0
|
||||
seq.append_token_id(sample.output_token, sample.logprobs,
|
||||
sample.output_embed)
|
||||
if not is_prefill_append:
|
||||
seq_group.update_num_computed_tokens(1)
|
||||
else:
|
||||
seq.append_token_id(sample.output_token, sample.logprobs,
|
||||
sample.output_embed)
|
||||
seq.append_token_id(sample.output_token, sample.logprobs,
|
||||
sample.output_embed)
|
||||
|
||||
def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
|
||||
"""Performs one decoding iteration and returns newly generated results.
|
||||
@@ -1289,13 +1186,6 @@ class LLMEngine:
|
||||
if not allow_async_output_proc and len(ctx.output_queue) > 0:
|
||||
self._process_model_outputs(ctx=ctx)
|
||||
|
||||
if (self.scheduler_config.is_multi_step
|
||||
and scheduler_outputs.num_lookahead_slots > 0):
|
||||
# cache the scheduler outputs for the next iteration if we have
|
||||
# lookahead slots
|
||||
self._cache_scheduler_outputs_for_multi_step(
|
||||
virtual_engine, seq_group_metadata_list, scheduler_outputs,
|
||||
allow_async_output_proc)
|
||||
else:
|
||||
finished_requests_ids = list()
|
||||
|
||||
@@ -1345,10 +1235,6 @@ class LLMEngine:
|
||||
# Raise so the caller is notified that this request failed
|
||||
raise
|
||||
|
||||
# We need to do this here so that last step's sampled_token_ids can
|
||||
# be passed to the next iteration for PP.
|
||||
if self.scheduler_config.is_multi_step:
|
||||
self._update_cached_scheduler_output(virtual_engine, outputs)
|
||||
else:
|
||||
# Nothing scheduled => If there is pending async postprocessor,
|
||||
# then finish it here.
|
||||
@@ -1357,19 +1243,9 @@ class LLMEngine:
|
||||
# No outputs in this case
|
||||
outputs = []
|
||||
|
||||
# Finish the current step for all the sequence groups.
|
||||
if self.scheduler_config.is_multi_step:
|
||||
for seq_group in seq_group_metadata_list:
|
||||
seq_group.finish_step()
|
||||
|
||||
if not self._has_remaining_steps(seq_group_metadata_list):
|
||||
# clear the cache if we have finished all the steps.
|
||||
if self.scheduler_config.is_multi_step:
|
||||
self.cached_scheduler_outputs[0] = SchedulerOutputState()
|
||||
|
||||
# is_first_step_output is True only when the num_steps of all
|
||||
# the sequences are 1. When the num_steps > 1,
|
||||
# multi_step_model_runner does the first-step output append.
|
||||
# the sequences are 1.
|
||||
is_first_step_output: bool = False if not seq_group_metadata_list \
|
||||
else seq_group_metadata_list[0].state.num_steps == 1
|
||||
|
||||
@@ -1453,22 +1329,7 @@ class LLMEngine:
|
||||
def _has_remaining_steps(
|
||||
self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
|
||||
) -> bool:
|
||||
if (not self.scheduler_config.is_multi_step
|
||||
or not seq_group_metadata_list):
|
||||
return False
|
||||
|
||||
# TODO(will) this is a sanity check for nowto make sure that all the
|
||||
# seqs are on the same steps. Eventually we will want to do some sort of
|
||||
# dynamic scheduling when doing multi-step decoding.
|
||||
ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
|
||||
if any([
|
||||
seq_group.state.remaining_steps != ref_remaining_steps
|
||||
for seq_group in seq_group_metadata_list[1:]
|
||||
]):
|
||||
raise AssertionError("All running sequence groups should "
|
||||
"have the same remaining steps.")
|
||||
|
||||
return ref_remaining_steps > 0
|
||||
return False
|
||||
|
||||
def _cache_scheduler_outputs_for_multi_step(
|
||||
self, virtual_engine: int,
|
||||
@@ -1497,13 +1358,6 @@ class LLMEngine:
|
||||
|
||||
def _get_last_sampled_token_ids(
|
||||
self, virtual_engine: int) -> Optional[torch.Tensor]:
|
||||
cached_last_output = self.cached_scheduler_outputs[
|
||||
virtual_engine].last_output
|
||||
if (self.scheduler_config.is_multi_step
|
||||
and self.parallel_config.pipeline_parallel_size > 1
|
||||
and cached_last_output is not None
|
||||
and cached_last_output.sampled_token_ids_cpu is not None):
|
||||
return cached_last_output.sampled_token_ids_cpu
|
||||
return None
|
||||
|
||||
def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
|
||||
|
||||
Reference in New Issue
Block a user