[core] [3/N] multi-step args and sequence.py (#7452)

This commit is contained in:
William Lin
2024-08-14 12:32:45 -07:00
committed by GitHub
parent 3f674a49b5
commit 2ecf7b1757
4 changed files with 100 additions and 5 deletions

View File

@@ -8,6 +8,7 @@ from dataclasses import dataclass, field
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
Union, cast)
import numpy
import torch
from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs
@@ -489,6 +490,19 @@ class Sequence:
f"num_blocks={self.n_blocks}, ")
@dataclass
class SequenceGroupState:
"""Mutable state tied to a specific sequence group"""
# for multi-step decoding
num_steps: int = 1
current_step: int = 0
@property
def remaining_steps(self) -> int:
return self.num_steps - self.current_step
class SequenceGroup:
"""A group of sequences that are generated from the same prompt.
@@ -534,6 +548,7 @@ class SequenceGroup:
time_in_queue=None)
self.lora_request = lora_request
self.prompt_logprobs: Optional[PromptLogprobs] = None
self.state = SequenceGroupState()
self.embeddings = embeddings
self.pooling_params = pooling_params
self.prompt_adapter_request = prompt_adapter_request
@@ -588,6 +603,10 @@ class SequenceGroup:
return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\
if self.prompt_adapter_request else 0
def init_multi_step(self, num_scheduler_steps: int) -> None:
self.state.num_steps = num_scheduler_steps
self.state.current_step = 0
def get_last_latency(self, now: float) -> Optional[float]:
"""Sets the last token time for Request level timings."""
# If still in prefill phase, raise Error.
@@ -756,6 +775,7 @@ class SequenceGroupMetadata:
lora_request: LoRA request.
computed_block_nums: The block numbers that are already computed,
used in prefix caching.
state: Internal state tied to this sequence group.
multi_modal_data: Multi modal data.
encoder_seq_data: Optional sequence data for encoder prompt
(SequenceGroup.encoder_seq). Should be None
@@ -781,6 +801,7 @@ class SequenceGroupMetadata:
token_chunk_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
computed_block_nums: Optional[List[int]] = None,
state: Optional[SequenceGroupState] = None,
multi_modal_data: Optional["MultiModalDataDict"] = None,
encoder_seq_data: Optional[SequenceData] = None,
cross_block_table: Optional[List[int]] = None,
@@ -796,6 +817,7 @@ class SequenceGroupMetadata:
self.prompt_adapter_request = prompt_adapter_request
self.computed_block_nums = computed_block_nums
self.multi_modal_data = multi_modal_data
self.state = SequenceGroupState() if state is None else state
self.encoder_seq_data = encoder_seq_data
self.cross_block_table = cross_block_table
self._token_chunk_size = token_chunk_size
@@ -834,6 +856,10 @@ class SequenceGroupMetadata:
assert self._token_chunk_size is not None
return self._token_chunk_size
def finish_step(self) -> None:
assert self.state.current_step < self.state.num_steps
self.state.current_step += 1
class SequenceOutput:
"""The model output associated with a sequence.
@@ -971,6 +997,7 @@ class SamplerOutput:
# On-device tensor containing the sampled token ids.
sampled_token_ids: Optional[torch.Tensor] = None
sampled_token_ids_numpy: Optional[numpy.ndarray] = None
# Spec decode metrics populated by workers.
spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
@@ -1112,6 +1139,33 @@ class ExecuteModelRequest:
num_steps: int = 1
# Finished request ids since last step.
finished_requests_ids: List[str] = field(default_factory=list)
# The last sampled token ids for multi step decoding.
last_sampled_token_ids: Optional[torch.Tensor] = None
@property
def is_first_multi_step(self) -> bool:
# TODO(will) make this be able to handle batches with variable number of
# steps
assert len(self.seq_group_metadata_list) > 0
first_seq_group = self.seq_group_metadata_list[0]
return first_seq_group.state.current_step == 0
@property
def is_last_step(self) -> bool:
# TODO(will) make this be able to handle batches with variable number of
# steps
assert len(self.seq_group_metadata_list) > 0
first_seq_group = self.seq_group_metadata_list[0]
num_steps = first_seq_group.state.num_steps
current_step = first_seq_group.state.current_step
return num_steps - current_step == 1
@property
def current_step(self) -> int:
# TODO(will) make this be able to handle batches with variable number of
# steps
assert len(self.seq_group_metadata_list) > 0
return self.seq_group_metadata_list[0].state.current_step
def clone(
self, seq_group_metadata_list: List[SequenceGroupMetadata]
@@ -1127,4 +1181,6 @@ class ExecuteModelRequest:
running_queue_size=self.running_queue_size,
previous_hidden_states=self.previous_hidden_states,
num_steps=self.num_steps,
finished_requests_ids=self.finished_requests_ids)
finished_requests_ids=self.finished_requests_ids,
last_sampled_token_ids=self.last_sampled_token_ids.clone()
if self.last_sampled_token_ids is not None else None)