[core] Multi Step Scheduling (#7000)

Co-authored-by: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com>
This commit is contained in:
William Lin
2024-08-19 13:52:13 -07:00
committed by GitHub
parent dad961ef5c
commit 47b65a5508
13 changed files with 1004 additions and 34 deletions

View File

@@ -9,7 +9,6 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
Tuple, Union, cast)
import msgspec
import numpy
import torch
from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs
@@ -1082,7 +1081,10 @@ class SamplerOutput(
# On-device tensor containing the sampled token ids.
sampled_token_ids: Optional[torch.Tensor] = None
sampled_token_ids_numpy: Optional[numpy.ndarray] = None
# CPU tensor containing the sampled token ids. Used during multi-step to
# return the sampled token ids from the last rank to AsyncLLMEngine so they
# can be 'broadcast' to all other PP ranks for the next step.
sampled_token_ids_cpu: Optional[torch.Tensor] = None
# Spec decode metrics populated by workers.
spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
@@ -1257,9 +1259,7 @@ class ExecuteModelRequest(
assert len(self.seq_group_metadata_list) > 0
first_seq_group = self.seq_group_metadata_list[0]
assert first_seq_group.state is not None
num_steps = first_seq_group.state.num_steps
current_step = first_seq_group.state.current_step
return num_steps - current_step == 1
return first_seq_group.state.remaining_steps == 1
@property
def current_step(self) -> int: