[core] Multi Step Scheduling (#7000)
Co-authored-by: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com>
This commit is contained in:
@@ -9,7 +9,6 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
|
||||
Tuple, Union, cast)
|
||||
|
||||
import msgspec
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs
|
||||
@@ -1082,7 +1081,10 @@ class SamplerOutput(
|
||||
|
||||
# On-device tensor containing the sampled token ids.
|
||||
sampled_token_ids: Optional[torch.Tensor] = None
|
||||
sampled_token_ids_numpy: Optional[numpy.ndarray] = None
|
||||
# CPU tensor containing the sampled token ids. Used during multi-step to
|
||||
# return the sampled token ids from last rank to AsyncLLMEngine to be
|
||||
# 'broadcasted' to all other PP ranks for next step.
|
||||
sampled_token_ids_cpu: Optional[torch.Tensor] = None
|
||||
|
||||
# Spec decode metrics populated by workers.
|
||||
spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
|
||||
@@ -1257,9 +1259,7 @@ class ExecuteModelRequest(
|
||||
assert len(self.seq_group_metadata_list) > 0
|
||||
first_seq_group = self.seq_group_metadata_list[0]
|
||||
assert first_seq_group.state is not None
|
||||
num_steps = first_seq_group.state.num_steps
|
||||
current_step = first_seq_group.state.current_step
|
||||
return num_steps - current_step == 1
|
||||
return first_seq_group.state.remaining_steps == 1
|
||||
|
||||
@property
|
||||
def current_step(self) -> int:
|
||||
|
||||
Reference in New Issue
Block a user