[core] Multi Step Scheduling (#7000)
Co-authored-by: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com>
@@ -1,9 +1,11 @@
import asyncio
import time
from dataclasses import dataclass
from functools import partial
from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
                    Optional, Set, Tuple, Type, Union)

import torch
from transformers import PreTrainedTokenizer
from typing_extensions import assert_never

@@ -27,7 +29,8 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
                           SequenceGroupMetadata)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import print_warning_once

@@ -249,9 +252,25 @@ class RequestTracker:
        return not self._new_requests.empty()


@dataclass
class SchedulerOutputState:
    """Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
    last_output: Optional[SamplerOutput] = None
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    scheduler_outputs: Optional[SchedulerOutputs] = None


class _AsyncLLMEngine(LLMEngine):
    """Extension of LLMEngine to add async methods."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        pipeline_parallel_size = \
            self.parallel_config.pipeline_parallel_size
        self.cached_scheduler_outputs = [
            SchedulerOutputState() for _ in range(pipeline_parallel_size)
        ]

    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:

@@ -264,13 +283,39 @@ class _AsyncLLMEngine(LLMEngine):
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.
        """
        seq_group_metadata_list, scheduler_outputs = self.scheduler[
            virtual_engine].schedule()
        # these are cached outputs from previous iterations. None if on first
        # iteration
        cached_outputs = self.cached_scheduler_outputs[virtual_engine]
        seq_group_metadata_list = cached_outputs.seq_group_metadata_list
        scheduler_outputs = cached_outputs.scheduler_outputs
        # skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        if not self._has_remaining_steps(seq_group_metadata_list):
            seq_group_metadata_list, scheduler_outputs = self.scheduler[
                virtual_engine].schedule()

            if (self.scheduler_config.is_multi_step
                    and scheduler_outputs.num_lookahead_slots > 0):
                # cache the scheduler outputs for the next iteration if we have
                # lookahead slots
                self._cache_scheduler_outputs_for_multi_step(
                    virtual_engine, seq_group_metadata_list, scheduler_outputs)

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():
            # Execute the model.
            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
            # sampled_token_ids, as a separate broadcast over all the PP stages
            # will cause one virtual engine's microbatch to block the pipeline.
            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,

@@ -279,15 +324,35 @@ class _AsyncLLMEngine(LLMEngine):
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids)
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids)
            # Execute the model.
            output = await self.model_executor.execute_model_async(
                execute_model_req)
            # we need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
                self._update_cached_scheduler_output(virtual_engine, output)
        else:
            output = []

        request_outputs = self._process_model_outputs(
            output, scheduler_outputs.scheduled_seq_groups,
            scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
        # Finish the current step for all the sequence groups.
        if self.scheduler_config.is_multi_step:
            for seq_group in seq_group_metadata_list:
                seq_group.finish_step()

        if not self._has_remaining_steps(seq_group_metadata_list):
            # clear the cache if we have finished all the steps
            if self.scheduler_config.is_multi_step:
                self.cached_scheduler_outputs[
                    virtual_engine] = SchedulerOutputState()
            request_outputs = self._process_model_outputs(
                output, scheduler_outputs.scheduled_seq_groups,
                scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
        else:
            request_outputs = []

        # Log stats.
        self.do_log_stats(scheduler_outputs, output)

@@ -297,6 +362,60 @@ class _AsyncLLMEngine(LLMEngine):

        return request_outputs

    def _has_remaining_steps(
        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
    ) -> bool:
        if (not self.scheduler_config.is_multi_step
                or not seq_group_metadata_list):
            return False

        # TODO(will) this is a sanity check for now to make sure that all the
        # seqs are on the same steps. Eventually we will want to do some sort of
        # dynamic scheduling when doing multi-step decoding.
        ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
        if any([
                seq_group.state.remaining_steps != ref_remaining_steps
                for seq_group in seq_group_metadata_list[1:]
        ]):
            raise AssertionError(("All running sequence groups should "
                                  "have the same remaining steps."))

        return ref_remaining_steps > 0

    def _cache_scheduler_outputs_for_multi_step(
            self, virtual_engine: int,
            seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
            scheduler_outputs: SchedulerOutputs) -> None:
        self.cached_scheduler_outputs[
            virtual_engine].seq_group_metadata_list = seq_group_metadata_list
        self.cached_scheduler_outputs[virtual_engine].scheduler_outputs = \
            scheduler_outputs
        self.cached_scheduler_outputs[virtual_engine].last_output = None

    def _get_last_sampled_token_ids(
            self, virtual_engine: int) -> Optional[torch.Tensor]:
        cached_last_output = self.cached_scheduler_outputs[
            virtual_engine].last_output
        if (self.scheduler_config.is_multi_step
                and self.parallel_config.pipeline_parallel_size > 1
                and cached_last_output is not None
                and cached_last_output.sampled_token_ids_cpu is not None):
            return cached_last_output.sampled_token_ids_cpu
        return None

    def _update_cached_scheduler_output(
            self, virtual_engine: int,
            output: List[Optional[SamplerOutput]]) -> None:
        if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
                and output[0] is not None):
            last_output = output[-1]
            assert last_output is not None
            assert last_output.sampled_token_ids_cpu is not None
            assert last_output.sampled_token_ids is None
            assert last_output.sampled_token_probs is None
            self.cached_scheduler_outputs[
                virtual_engine].last_output = last_output

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote worker execution loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()
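
For orientation, the control flow this diff introduces can be summarized in a small stand-alone sketch: the scheduler is invoked once per multi-step batch, its outputs are cached per virtual engine, and later calls to step_async reuse the cache until every sequence group has consumed its remaining steps, at which point the cache is cleared and the scheduler runs again. The sketch below is illustrative only and is not code from this commit; names such as ToyEngine, ToyState, and num_steps are made up, and the model execution and pipeline-parallel details are omitted.

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyState:
    # illustrative stand-in for SequenceGroupMetadata.state
    remaining_steps: int


@dataclass
class ToyCachedOutputs:
    # illustrative stand-in for SchedulerOutputState
    seq_group_metadata_list: Optional[List[ToyState]] = None


class ToyEngine:
    """Toy model of the multi-step flow: schedule once, then reuse the
    cached batch until all remaining steps are consumed."""

    def __init__(self, num_steps: int, batch: List[ToyState]):
        self.cache = ToyCachedOutputs()
        self.num_steps = num_steps
        self.batch = batch
        self.schedule_calls = 0

    def _schedule(self) -> List[ToyState]:
        # stands in for self.scheduler[virtual_engine].schedule()
        self.schedule_calls += 1
        for sg in self.batch:
            sg.remaining_steps = self.num_steps
        return self.batch

    def _has_remaining_steps(self, cached: Optional[List[ToyState]]) -> bool:
        return bool(cached) and cached[0].remaining_steps > 0

    def step(self) -> None:
        cached = self.cache.seq_group_metadata_list
        if not self._has_remaining_steps(cached):
            # only call the scheduler when the previous batch is done
            cached = self._schedule()
            self.cache.seq_group_metadata_list = cached
        # "run the model", then finish one step per sequence group
        for sg in cached:
            sg.remaining_steps -= 1
        if not self._has_remaining_steps(cached):
            # clear the cache once the whole multi-step batch has finished
            self.cache = ToyCachedOutputs()


engine = ToyEngine(num_steps=4, batch=[ToyState(0), ToyState(0)])
for _ in range(8):
    engine.step()
print(engine.schedule_calls)  # 2: the scheduler ran once per 4-step batch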