[Core] Optimize SPMD architecture with delta + serialization optimization (#7109)
vllm/sequence.py (386 changed lines)
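Summary: this change converts the scheduler-to-worker data structures in vllm/sequence.py from Python dataclasses to msgspec.Struct so they can be serialized efficiently, and introduces SequenceDataDelta and SequenceGroupMetadataDelta so that, after the first full sync, the SPMD driver only ships per-step deltas to the workers.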
@@ -4,10 +4,11 @@ import enum
 from abc import ABC, abstractmethod
 from array import array
 from collections import defaultdict
-from dataclasses import dataclass, field
-from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
-                    Union, cast)
+from dataclasses import dataclass
+from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
+                    Tuple, Union, cast)
 
+import msgspec
 import numpy
 import torch
 
@@ -16,13 +17,18 @@ from vllm.lora.request import LoRARequest
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
+from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
 
 if TYPE_CHECKING:
     from vllm.inputs import LLMInputs
-    from vllm.multimodal import MultiModalDataDict
-    from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
+    from vllm.multimodal.base import MultiModalDataDict
 
+VLLM_TOKEN_ID_ARRAY_TYPE = "l"
+
 
+# We use dataclass for now because it is used for
+# openai server output, and msgspec is not serializable.
+# TODO(sang): Fix it.
 @dataclass
 class Logprob:
     """Infos for supporting OpenAI compatible logprobs and token ranks.
@@ -112,7 +118,23 @@ class RequestMetrics:
     model_execute_time: Optional[float] = None
 
 
-class SequenceData:
+class SequenceDataDelta(
+        msgspec.Struct,
+        array_like=True,  # type: ignore[call-arg]
+        omit_defaults=True):  # type: ignore[call-arg]
+    """Delta SequenceData to send to workers per step."""
+    # A new token to be appended to existing SequenceData.
+    new_output_token_ids: List[int]
+    # Overwriting existing `cumulative_logprob`
+    new_cumulative_logprob: float
+    # Overwriting existing `num_computed_tokens`.
+    new_num_computed_tokens: int
+    # Overwriting existing `stage`.
+    new_stage: SequenceStage
+
+
+class SequenceData(msgspec.Struct,
+                   omit_defaults=True):  # type: ignore[call-arg]
    """Data associated with a sequence.

    Args:
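
Two msgspec options recur throughout this diff and carry the serialization optimization: array_like=True encodes a struct as a positional array rather than a field-name map, and omit_defaults=True drops fields that still hold their default values. A minimal standalone sketch of the effect (illustration only, not vLLM code; the class names here are made up):

    import msgspec

    class Compact(msgspec.Struct, array_like=True, omit_defaults=True):
        request_id: str
        is_prompt: bool = False
        token_chunk_size: int = 0

    class Plain(msgspec.Struct):
        request_id: str
        is_prompt: bool = False
        token_chunk_size: int = 0

    enc = msgspec.msgpack.Encoder()
    # Field names are omitted and trailing default-valued fields are dropped,
    # so the compact encoding is a fraction of the map-based one.
    print(len(enc.encode(Compact("req-0"))))  # 7 bytes
    print(len(enc.encode(Plain("req-0"))))    # 47 bytes

The same trade-off applies to the structs below: positional encoding is smaller and faster, at the cost that renaming or reordering fields becomes a wire-format change.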
@@ -125,40 +147,57 @@
         output_token_ids: The token IDs of the output.
         cumulative_logprob: The cumulative log probability of the output.
     """
+    # NOTE: we cannot use Union[List, array] because msgspec cannot support
+    # union of 2 list types.
+    _prompt_token_ids: array
+    _output_token_ids: array = msgspec.field(
+        default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []))
 
-    def __init__(
-        self,
-        prompt_token_ids: List[int],
-        output_token_ids: Optional[List[int]] = None,
-    ) -> None:
-        self._prompt_token_ids = array('l', prompt_token_ids)
-        self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids)
-        self._output_token_ids = array(
-            'l', output_token_ids if output_token_ids is not None else [])
+    ### The below fields should not be passed as an argument ###
+    _cumulative_logprob: float = 0.0
+    _prompt_token_ids_tuple: Tuple[int,
+                                   ...] = msgspec.field(default_factory=tuple)
+    # The number of tokens that are computed (that run against the model).
+    _num_computed_tokens: int = 0
+    _stage: SequenceStage = SequenceStage.PREFILL
+    _cached_all_token_ids: List[int] = msgspec.field(default_factory=list)
 
-        self.cumulative_logprob = 0.0
-        # The number of tokens that are computed (that run against the model).
-        self._num_computed_tokens = 0
-        self._stage: SequenceStage = SequenceStage.PREFILL
+    # It is used to get delta input. It is reset when `get_delta_and_reset`
+    # is called.
+    _new_appended_tokens: List[int] = msgspec.field(default_factory=list)
 
+    def __post_init__(self) -> None:
+        assert self._prompt_token_ids.typecode == "l"
+        assert self._output_token_ids.typecode == "l"
+        self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(
+            self._prompt_token_ids)
         self._update_cached_all_tokens()
 
     def _update_cached_all_tokens(self):
+        assert isinstance(self._prompt_token_ids, array)
+        assert isinstance(self._output_token_ids, array)
         self._cached_all_token_ids: List[int] = list(self._prompt_token_ids +
                                                      self._output_token_ids)
 
+    @property
+    def cumulative_logprob(self) -> float:
+        return self._cumulative_logprob
+
     @property
     def prompt_token_ids(self) -> Tuple[int, ...]:
         return self._prompt_token_ids_tuple
 
     @prompt_token_ids.setter
     def prompt_token_ids(self, new_prompt_token_ids) -> None:
-        self._prompt_token_ids = array('l', new_prompt_token_ids)
-        self._prompt_token_ids_tuple = tuple(new_prompt_token_ids)
-        self._update_cached_all_tokens()
+        raise NotImplementedError
 
     @property
     def prompt_token_ids_array(self) -> array:
         """Return the prompt token ids in array type.
 
         Note that the array is in "I" type, and it is not compatible
         with torch.long (2 bytes vs 4 bytes). So beware of the usage.
         """
         return self._prompt_token_ids
 
     @property
@@ -166,18 +205,26 @@ class SequenceData:
         return tuple(self._output_token_ids)
 
     @output_token_ids.setter
-    def output_token_ids(self, new_output_token_ids) -> None:
-        self._output_token_ids = array('l', new_output_token_ids)
+    def output_token_ids(self, new_output_token_ids: List[int]) -> None:
+        self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                                       new_output_token_ids)
         self._update_cached_all_tokens()
 
+    @property
+    def output_token_ids_array(self) -> array:
+        """Return the prompt token ids in array type.
+
+        Note that the array is in "I" type, and it is not compatible
+        with torch.long (2 bytes vs 4 bytes). So beware of the usage.
+        """
+        assert isinstance(self._output_token_ids, array)
+        return self._output_token_ids
+
     def append_token_id(self, token_id: int, logprob: float) -> None:
         self._output_token_ids.append(token_id)
+        self._new_appended_tokens.append(token_id)
         self._cached_all_token_ids.append(token_id)
-        self.cumulative_logprob += logprob
+        self._cumulative_logprob += logprob
 
     def get_len(self) -> int:
         return len(self._output_token_ids) + len(self._prompt_token_ids)
@@ -222,6 +269,7 @@ class SequenceData:
         """
         self._num_computed_tokens = 0
         self._stage = SequenceStage.PREFILL
+        self._new_appended_tokens = []
 
     def get_num_uncomputed_tokens(self) -> int:
         """Return the number of prefill tokens that are not computed."""
@@ -241,6 +289,21 @@ class SequenceData:
     def get_output_token_ids(self) -> Tuple[int, ...]:
         return self.output_token_ids
 
+    def get_delta_and_reset(self) -> SequenceDataDelta:
+        delta = SequenceDataDelta(self._new_appended_tokens,
+                                  self._cumulative_logprob,
+                                  self.get_num_computed_tokens(), self.stage)
+        # Reset delta state.
+        self._new_appended_tokens = []
+        return delta
+
+    def apply_delta(self, delta: SequenceDataDelta):
+        self._num_computed_tokens = delta.new_num_computed_tokens
+        self._cumulative_logprob = delta.new_cumulative_logprob
+        self._stage = delta.new_stage
+        self._output_token_ids.extend(delta.new_output_token_ids)
+        self._cached_all_token_ids.extend(delta.new_output_token_ids)
+
     @property
     def stage(self) -> SequenceStage:
         return self._stage
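
Taken together, get_delta_and_reset (driver side) and apply_delta (worker side) are the heart of the delta optimization: after the first full sync, only the tokens appended since the last step cross the process boundary. A rough sketch of the intended round trip, using only the methods shown in this diff (the RPC plumbing itself lives outside this file):

    from array import array

    # Driver and worker each hold a copy of the same sequence state.
    driver_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
    worker_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))

    # Driver appends a sampled token, then extracts the per-step delta.
    driver_data.append_token_id(token_id=42, logprob=-0.25)
    delta = driver_data.get_delta_and_reset()

    # Worker patches its cached copy with the (much smaller) delta
    # instead of receiving a full SequenceData again.
    worker_data.apply_delta(delta)
    assert worker_data.get_output_token_ids() == (42,)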
@@ -248,8 +311,9 @@
     def __repr__(self) -> str:
         return (f"SequenceData("
                 f"prompt_token_ids={self._prompt_token_ids}, "
-                f"output_token_ids={self._output_token_ids}, "
-                f"cumulative_logprob={self.cumulative_logprob})")
+                f"output_token_ids={self.output_token_ids}, "
+                f"cumulative_logprob={self.cumulative_logprob}, "
+                f"get_num_computed_tokens={self.get_num_computed_tokens()}")
 
 
 class Sequence:
@@ -325,7 +389,8 @@ class Sequence:
                 f"invalid input {inputs}; did you forget the "
                 "encoder input prompt fields?")
 
-        self.data = SequenceData(self.prompt_token_ids)
+        self.data = SequenceData(
+            array(VLLM_TOKEN_ID_ARRAY_TYPE, self.prompt_token_ids))
         self.output_logprobs: SampleLogprobs = []
         self.output_text = ""
 
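Since SequenceData is now a msgspec struct whose first field is an array, callers can no longer pass a plain List[int]; the Sequence constructor above wraps the prompt ids accordingly. Any other call site that builds a SequenceData needs the same treatment, for example:

    from array import array

    prompt_token_ids = [101, 2009, 2003]
    seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids))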
@@ -490,8 +555,8 @@
             f"num_blocks={self.n_blocks}, ")
 
 
-@dataclass
-class SequenceGroupState:
+class SequenceGroupState(msgspec.Struct,
+                         omit_defaults=True):  # type: ignore[call-arg]
     """Mutable state tied to a specific sequence group"""
 
     # for multi-step decoding
@@ -647,14 +712,19 @@ class SequenceGroup:
         if self.sampling_params and self.sampling_params.use_beam_search:
             # For beam search, maximally there will always be `best_of` beam
             # candidates running in the future.
-            return self.sampling_params.best_of
+            best_of = self.sampling_params.best_of
+            assert isinstance(best_of, int)
+            return best_of
         else:
-            if (self.sampling_params
-                    and self.sampling_params.best_of > self.num_seqs()):
-                # At prompt stage, the sequence group is not yet filled up
-                # and only have one sequence running. However, in the
-                # generation stage, we will have `best_of` sequences running.
-                return self.sampling_params.best_of
+            if self.sampling_params:
+                best_of = self.sampling_params.best_of
+                assert isinstance(best_of, int)
+                if best_of > self.num_seqs():
+                    # At prompt stage, the sequence group is not yet filled up
+                    # and only have one sequence running. However, in the
+                    # generation stage, we will have `best_of` sequences
+                    # running.
+                    return best_of
             # At sampling stages, return the number of actual sequences
             # that are not finished yet.
             return self.num_unfinished_seqs()
@@ -757,7 +827,32 @@
                 f"num_seqs={len(self.seqs)})")
 
 
-class SequenceGroupMetadata:
+class SequenceGroupMetadataDelta(
+        msgspec.Struct,
+        tag=True,  # type: ignore[call-arg]
+        array_like=True,  # type: ignore[call-arg]
+        omit_defaults=True):  # type: ignore[call-arg]
+    """Delta of SequenceGroupMetadata.
+
+    After sending the first SequenceGroupMetadata, vLLM scheduler
+    only sends delta to reduce the data payload size.
+    """
+    seq_data_delta: Dict[int, SequenceDataDelta]
+    request_id: str
+    block_tables: Dict[int, List[int]]
+    is_prompt: bool
+    do_sample: bool = True
+    token_chunk_size: Optional[int] = None
+    computed_block_nums: Optional[List[int]] = None
+    state: Optional[SequenceGroupState] = msgspec.field(
+        default_factory=lambda: SequenceGroupState())
+
+
+class SequenceGroupMetadata(
+        msgspec.Struct,
+        tag=True,  # type: ignore[call-arg]
+        array_like=True,  # type: ignore[call-arg]
+        omit_defaults=True):  # type: ignore[call-arg]
     """Metadata for a sequence group. Used to create `AttentionMetadata`.
 
     Args:
@@ -789,52 +884,39 @@
         prompt_adapter_request: Prompt Adapter request.
     """
 
-    def __init__(
-        self,
-        request_id: str,
-        is_prompt: bool,
-        seq_data: Dict[int, SequenceData],
-        sampling_params: SamplingParams,
-        block_tables: Dict[int, List[int]],
-        do_sample: bool = True,
-        pooling_params: Optional[PoolingParams] = None,
-        token_chunk_size: Optional[int] = None,
-        lora_request: Optional[LoRARequest] = None,
-        computed_block_nums: Optional[List[int]] = None,
-        state: Optional[SequenceGroupState] = None,
-        multi_modal_data: Optional["MultiModalDataDict"] = None,
-        encoder_seq_data: Optional[SequenceData] = None,
-        cross_block_table: Optional[List[int]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> None:
-        self.request_id = request_id
-        self.is_prompt = is_prompt
-        self.seq_data = seq_data
-        self.sampling_params = sampling_params
-        self.block_tables = block_tables
-        self.pooling_params = pooling_params
-        self.lora_request = lora_request
-        self.prompt_adapter_request = prompt_adapter_request
-        self.computed_block_nums = computed_block_nums
-        self.multi_modal_data = multi_modal_data
-        self.state = SequenceGroupState() if state is None else state
-        self.encoder_seq_data = encoder_seq_data
-        self.cross_block_table = cross_block_table
-        self._token_chunk_size = token_chunk_size
-        self.do_sample = do_sample
+    request_id: str
+    is_prompt: bool
+    seq_data: Dict[int, SequenceData]
+    sampling_params: SamplingParams
+    block_tables: Dict[int, List[int]]
+    do_sample: bool = True
+    pooling_params: Optional[PoolingParams] = None
+    lora_request: Optional[LoRARequest] = None
+    computed_block_nums: Optional[List[int]] = None
+    state: Optional[SequenceGroupState] = msgspec.field(
+        default_factory=lambda: SequenceGroupState())
+    # "MultiModalDataDict" types. We have to use Any due to msgspec
+    # doesn't allow to have union of 2 different dicts.
+    multi_modal_data: Optional[Any] = None
+    encoder_seq_data: Optional[SequenceData] = None
+    cross_block_table: Optional[List[int]] = None
+    prompt_adapter_request: Optional[PromptAdapterRequest] = None
+    token_chunk_size: Optional[int] = None
 
-        # The number of speculative tokens adopted in this request.
-        # None means specuative decoding is not used.
-        # Zero means speculative decoding is disabled for some reasons.
-        # TODO: We should maintain this states out of the sequence group.
-        self.num_speculative_tokens = None
+    ### Stateful fields that are lazily defined. ###
+    # The number of speculative tokens adopted in this request.
+    # None means specuative decoding is not used.
+    # Zero means speculative decoding is disabled for some reasons.
+    # TODO: We should maintain this states out of the sequence group.
+    num_speculative_tokens: Optional[int] = None
 
-        if seq_data is not None and self._token_chunk_size is None:
-            if is_prompt:
-                self._token_chunk_size = next(iter(
-                    seq_data.values())).get_len()
+    def __post_init__(self):
+        if self.seq_data is not None and self.token_chunk_size is None:
+            if self.is_prompt:
+                self.token_chunk_size = next(iter(
+                    self.seq_data.values())).get_len()
             else:
-                self._token_chunk_size = 1
+                self.token_chunk_size = 1
 
     @property
     def lora_int_id(self) -> int:
@@ -850,18 +932,26 @@ class SequenceGroupMetadata:
         return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \
                         if self.prompt_adapter_request else 0
 
-    @property
-    def token_chunk_size(self) -> int:
-        """Return the number of tokens to be processed (chunk size)."""
-        assert self._token_chunk_size is not None
-        return self._token_chunk_size
+    def apply_delta(self,
+                    sequence_group_metadata_delta: SequenceGroupMetadataDelta):
+        for id, delta in sequence_group_metadata_delta.seq_data_delta.items():
+            self.seq_data[id].apply_delta(delta)
+        assert self.request_id == sequence_group_metadata_delta.request_id
+        self.block_tables = sequence_group_metadata_delta.block_tables
+        self.token_chunk_size = sequence_group_metadata_delta.token_chunk_size
+        self.do_sample = sequence_group_metadata_delta.do_sample
+        self.is_prompt = sequence_group_metadata_delta.is_prompt
 
     def finish_step(self) -> None:
+        assert self.state is not None
         assert self.state.current_step < self.state.num_steps
         self.state.current_step += 1
 
 
-class SequenceOutput:
+class SequenceOutput(
+        msgspec.Struct,
+        omit_defaults=True,  # type: ignore[call-arg]
+        array_like=True):  # type: ignore[call-arg]
     """The model output associated with a sequence.
 
     Args:
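
Because SequenceGroupMetadata and SequenceGroupMetadataDelta both set tag=True, msgspec writes a type tag into each encoded message, so a single decoder can handle Union[SequenceGroupMetadata, SequenceGroupMetadataDelta] and the receiver can dispatch on the concrete type. The intended flow is that a worker caches the full metadata from the first step and patches it with deltas afterwards; a hypothetical worker-side sketch (the real cache lives in the worker code, not in this file):

    # Illustrative cache keyed by request id.
    seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}

    def on_metadata(
        msg: Union[SequenceGroupMetadata, SequenceGroupMetadataDelta]
    ) -> SequenceGroupMetadata:
        if isinstance(msg, SequenceGroupMetadata):
            # First step: store the full metadata.
            seq_group_metadata_cache[msg.request_id] = msg
        else:
            # Later steps: patch the cached copy in place.
            seq_group_metadata_cache[msg.request_id].apply_delta(msg)
        return seq_group_metadata_cache[msg.request_id]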
@@ -871,16 +961,9 @@
         logprobs: The logprobs of the output token.
             (Token id -> logP(x_i+1 | x_0, ..., x_i))
     """
 
-    def __init__(
-        self,
-        parent_seq_id: int,
-        output_token: int,
-        logprobs: Dict[int, Logprob],
-    ) -> None:
-        self.parent_seq_id = parent_seq_id
-        self.output_token = output_token
-        self.logprobs = logprobs
+    parent_seq_id: int
+    output_token: int
+    logprobs: Dict[int, Logprob]
 
     def __repr__(self) -> str:
         return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
@@ -908,17 +991,15 @@ class SequenceGroupOutput(ABC):
         pass
 
 
-class CompletionSequenceGroupOutput(SequenceGroupOutput):
+class CompletionSequenceGroupOutput(
+        msgspec.Struct,
+        omit_defaults=True,  # type: ignore[call-arg]
+        array_like=True):  # type: ignore[call-arg]
+    __metaclass__ = SequenceGroupOutput
     """The model output associated with a completion sequence group."""
 
-    def __init__(
-        self,
-        samples: List[SequenceOutput],
-        prompt_logprobs: Optional[PromptLogprobs],
-    ) -> None:
-        self.samples = samples
-        # Prompt logprob for each prompt query token.
-        self.prompt_logprobs = prompt_logprobs
+    samples: List[SequenceOutput]
+    # Prompt logprob for each prompt query token.
+    prompt_logprobs: Optional[PromptLogprobs]
 
     def __repr__(self) -> str:
         return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
@@ -931,14 +1012,14 @@ class CompletionSequenceGroupOutput(SequenceGroupOutput):
                 and self.prompt_logprobs == other.prompt_logprobs)
 
 
-class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
+class EmbeddingSequenceGroupOutput(
+        msgspec.Struct,
+        omit_defaults=True,  # type: ignore[call-arg]
+        array_like=True,  # type: ignore[call-arg]
+):
     """The model output associated with an embedding sequence group."""
 
-    def __init__(
-        self,
-        embeddings: List[float],
-    ) -> None:
-        self.embeddings = embeddings
+    __metaclass__ = SequenceGroupOutput
+    embeddings: List[int]
 
     def __repr__(self) -> str:
         return (f"EmbeddingSequenceGroupOutput("
@@ -950,8 +1031,10 @@ class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
         return self.embeddings == other.embeddings
 
 
-@dataclass
-class IntermediateTensors:
+class IntermediateTensors(
+        msgspec.Struct,
+        omit_defaults=True,  # type: ignore[call-arg]
+        array_like=True):  # type: ignore[call-arg]
     """For all pipeline stages except the last, we need to return the hidden
     states and residuals to be sent to the next stage. This data structure
     contains the hidden states and residuals for a request.
@@ -978,8 +1061,10 @@ class IntermediateTensors:
         return f"IntermediateTensors(tensors={self.tensors})"
 
 
-@dataclass
-class SamplerOutput:
+class SamplerOutput(
+        msgspec.Struct,
+        omit_defaults=True,  # type: ignore[call-arg]
+        array_like=True):  # type: ignore[call-arg]
     """For each sequence group, we generate a list of SequenceOutput object,
     each of which contains one possible candidate for the next token.
 
@@ -1000,7 +1085,7 @@ class SamplerOutput:
     sampled_token_ids_numpy: Optional[numpy.ndarray] = None
 
     # Spec decode metrics populated by workers.
-    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
+    spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
 
     # Optional last hidden states from the model.
     hidden_states: Optional[torch.Tensor] = None
@@ -1039,12 +1124,14 @@ class SamplerOutput:
             f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
 
 
-@dataclass
-class PoolerOutput:
+class PoolerOutput(
+        msgspec.Struct,
+        omit_defaults=True,  # type: ignore[call-arg]
+        array_like=True):  # type: ignore[call-arg]
     """The output from a pooling operation in the embedding model."""
     outputs: List[EmbeddingSequenceGroupOutput]
 
-    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
+    spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
 
     def __getitem__(self, idx: int):
         return self.outputs[idx]
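
One consequence of moving these classes to msgspec: the msgpack encoder has no built-in support for field types such as array or torch.Tensor, so whichever encoder/decoder pair ships these structs must be configured with custom hooks. A minimal sketch of such hooks for the token-id arrays (an assumption for illustration; the concrete encoder setup is not part of this file):

    from array import array
    from typing import Any, Type

    import msgspec

    def enc_hook(obj: Any) -> Any:
        # Ship token-id arrays as raw bytes.
        if isinstance(obj, array):
            return obj.tobytes()
        raise NotImplementedError(
            f"Objects of type {type(obj)} are not supported")

    def dec_hook(typ: Type, obj: Any) -> Any:
        # Rebuild the array from raw bytes on the receiving side.
        if typ is array:
            arr = array(VLLM_TOKEN_ID_ARRAY_TYPE)
            arr.frombytes(obj)
            return arr
        raise NotImplementedError(f"Type {typ} is not supported")

    encoder = msgspec.msgpack.Encoder(enc_hook=enc_hook)
    decoder = msgspec.msgpack.Decoder(SequenceData, dec_hook=dec_hook)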
@@ -1083,7 +1170,8 @@ def get_all_seq_ids_and_request_ids(
     return seq_ids, request_id_seq_ids_mapping
 
 
-class HiddenStates:
+class HiddenStates(msgspec.Struct, array_like=True,
+                   omit_defaults=True):  # type: ignore[call-arg]
     """Hidden states corresponding to in-progress sequences.
     Used in speculative decoding to pass hidden states from
     the target model to the proposer model in the subsequent step.
@@ -1091,42 +1179,53 @@
     seq_ids are the sequence ids of each entry of the batch
     dimension of the hidden_states tensor"""
 
-    def __init__(self, seq_group_metadata_list: List[SequenceGroupMetadata],
-                 hidden_states: torch.Tensor):
-        assert len(seq_group_metadata_list) == len(hidden_states)
-        self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list)
-        self.hidden_states: torch.Tensor = hidden_states
+    seq_group_metadata_list: List[SequenceGroupMetadata]
+    hidden_states: torch.Tensor
+    _seq_ids: List[int] = msgspec.field(default_factory=list)
+
+    def __post_init__(self):
+        self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)
+        assert len(self.seq_group_metadata_list) == len(self.hidden_states)
+
+    @property
+    def seq_ids(self) -> List[int]:
+        return self._seq_ids
 
     def update(self, seq_group_metadata_list: List[SequenceGroupMetadata],
                hidden_states: torch.Tensor) -> None:
         """Update hidden states from target model invocation."""
         assert len(seq_group_metadata_list) == len(hidden_states)
-        self.seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
+        self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
         self.hidden_states = torch.cat([self.hidden_states, hidden_states])
 
     def prune(self,
               seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
         """Prune to provided list of sequence ids."""
         seq_ids = get_all_seq_ids(seq_group_metadata_list)
-        if seq_ids != self.seq_ids:
+        if seq_ids != self._seq_ids:
             # Batch contents changed - prune removed sequences.
-            index = [self.seq_ids.index(seq_id) for seq_id in seq_ids]
+            index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
             self.hidden_states = self.hidden_states[index]
-            self.seq_ids = seq_ids
+            self._seq_ids = seq_ids
 
 
-@dataclass
-class ExecuteModelRequest:
+class ExecuteModelRequest(
+        msgspec.Struct,
+        array_like=True,  # type: ignore[call-arg]
+        omit_defaults=True):  # type: ignore[call-arg]
     """The model execution request, containing CPU metadata only. The LLM
     engine should create an instance of this class for each request batch."""
     # The sequence group metadata list.
-    seq_group_metadata_list: List[SequenceGroupMetadata]
+    seq_group_metadata_list: List[Union[SequenceGroupMetadata,
+                                        SequenceGroupMetadataDelta]]
     # Blocks to swap in. List of CPU -> GPU block number.
-    blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list)
+    blocks_to_swap_in: List[Tuple[int,
+                                  int]] = msgspec.field(default_factory=list)
     # Blocks to swap out. List of GPU -> CPU block number.
-    blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
+    blocks_to_swap_out: List[Tuple[int,
+                                   int]] = msgspec.field(default_factory=list)
     # Blocks to copy. Source to dest block.
-    blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
+    blocks_to_copy: List[Tuple[int, int]] = msgspec.field(default_factory=list)
     # Virtual engine ID for pipeline parallel.
     virtual_engine: int = 0
     # The number of slots for lookahead decoding.
@@ -1138,7 +1237,7 @@ class ExecuteModelRequest:
     # The number of forward steps to run.
     num_steps: int = 1
     # Finished request ids since last step.
-    finished_requests_ids: List[str] = field(default_factory=list)
+    finished_requests_ids: List[str] = msgspec.field(default_factory=list)
     # The last sampled token ids for multi step decoding.
     last_sampled_token_ids: Optional[torch.Tensor] = None
 
@@ -1148,6 +1247,7 @@ class ExecuteModelRequest:
         # steps
         assert len(self.seq_group_metadata_list) > 0
         first_seq_group = self.seq_group_metadata_list[0]
+        assert first_seq_group.state is not None
         return first_seq_group.state.current_step == 0
 
     @property
@@ -1156,6 +1256,7 @@ class ExecuteModelRequest:
         # steps
         assert len(self.seq_group_metadata_list) > 0
         first_seq_group = self.seq_group_metadata_list[0]
+        assert first_seq_group.state is not None
        num_steps = first_seq_group.state.num_steps
        current_step = first_seq_group.state.current_step
        return num_steps - current_step == 1
@@ -1165,10 +1266,13 @@ class ExecuteModelRequest:
         # TODO(will) make this be able to handle batches with variable number of
         # steps
         assert len(self.seq_group_metadata_list) > 0
-        return self.seq_group_metadata_list[0].state.current_step
+        state = self.seq_group_metadata_list[0].state
+        assert state is not None
+        return state.current_step
 
     def clone(
-        self, seq_group_metadata_list: List[SequenceGroupMetadata]
+        self, seq_group_metadata_list: List[Union[SequenceGroupMetadata,
+                                                  SequenceGroupMetadataDelta]]
     ) -> "ExecuteModelRequest":
         """Clone the request with a new sequence group metadata list."""
         return ExecuteModelRequest(
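
End to end, the driver can encode a whole ExecuteModelRequest and a worker can decode it back, with the tagged union on seq_group_metadata_list distinguishing full metadata from deltas. A rough usage sketch, reusing the hypothetical hooks from the earlier example and assuming seq_group_metadata was built as in the hunks above (and that every nested field type is msgspec-serializable or covered by the hooks):

    import msgspec

    encoder = msgspec.msgpack.Encoder(enc_hook=enc_hook)
    decoder = msgspec.msgpack.Decoder(ExecuteModelRequest, dec_hook=dec_hook)

    req = ExecuteModelRequest(seq_group_metadata_list=[seq_group_metadata])
    payload = encoder.encode(req)       # compact bytes thanks to array_like
    received = decoder.decode(payload)  # back to an ExecuteModelRequest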