[Core] Optimize SPMD architecture with delta + serialization optimization (#7109)

This commit is contained in:
SangBin Cho
2024-08-18 17:57:20 -07:00
committed by GitHub
parent 200a2ffa6b
commit ff7ec82c4d
36 changed files with 722 additions and 346 deletions

View File

@@ -4,10 +4,11 @@ import enum
from abc import ABC, abstractmethod
from array import array
from collections import defaultdict
from dataclasses import dataclass, field
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
Union, cast)
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
Tuple, Union, cast)
import msgspec
import numpy
import torch
@@ -16,13 +17,18 @@ from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
if TYPE_CHECKING:
from vllm.inputs import LLMInputs
from vllm.multimodal import MultiModalDataDict
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
from vllm.multimodal.base import MultiModalDataDict
VLLM_TOKEN_ID_ARRAY_TYPE = "l"
# We use dataclass for now because it is used for
# openai server output, and msgspec is not serializable.
# TODO(sang): Fix it.
@dataclass
class Logprob:
"""Infos for supporting OpenAI compatible logprobs and token ranks.
@@ -112,7 +118,23 @@ class RequestMetrics:
model_execute_time: Optional[float] = None
class SequenceData:
class SequenceDataDelta(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Delta SequenceData to send to workers per step.

    Carries only the fields of SequenceData that change between scheduler
    steps, so the full SequenceData does not need to be re-serialized.
    """
    # New output tokens to be appended to the existing SequenceData.
    new_output_token_ids: List[int]
    # Overwrites the existing `cumulative_logprob`.
    new_cumulative_logprob: float
    # Overwrites the existing `num_computed_tokens`.
    new_num_computed_tokens: int
    # Overwrites the existing `stage`.
    new_stage: SequenceStage
class SequenceData(msgspec.Struct,
omit_defaults=True): # type: ignore[call-arg]
"""Data associated with a sequence.
Args:
@@ -125,40 +147,57 @@ class SequenceData:
output_token_ids: The token IDs of the output.
cumulative_logprob: The cumulative log probability of the output.
"""
# NOTE: we cannot use Union[List, array] because msgspec cannot support
# union of 2 list types.
_prompt_token_ids: array
_output_token_ids: array = msgspec.field(
default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []))
def __init__(
self,
prompt_token_ids: List[int],
output_token_ids: Optional[List[int]] = None,
) -> None:
self._prompt_token_ids = array('l', prompt_token_ids)
self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids)
self._output_token_ids = array(
'l', output_token_ids if output_token_ids is not None else [])
### The below fields should not be passed as an argument ###
_cumulative_logprob: float = 0.0
_prompt_token_ids_tuple: Tuple[int,
...] = msgspec.field(default_factory=tuple)
# The number of tokens that are computed (that run against the model).
_num_computed_tokens: int = 0
_stage: SequenceStage = SequenceStage.PREFILL
_cached_all_token_ids: List[int] = msgspec.field(default_factory=list)
self.cumulative_logprob = 0.0
# The number of tokens that are computed (that run against the model).
self._num_computed_tokens = 0
self._stage: SequenceStage = SequenceStage.PREFILL
# It is used to get delta input. It is reset when `get_delta_and_reset`
# is called.
_new_appended_tokens: List[int] = msgspec.field(default_factory=list)
def __post_init__(self) -> None:
    """Validate token arrays and build derived caches after construction."""
    # Token id arrays must use the "l" typecode (VLLM_TOKEN_ID_ARRAY_TYPE);
    # msgspec cannot enforce the typecode itself, so assert it here.
    assert self._prompt_token_ids.typecode == "l"
    assert self._output_token_ids.typecode == "l"
    # Immutable tuple view of the prompt, cached for cheap repeated access.
    self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(
        self._prompt_token_ids)
    self._update_cached_all_tokens()
def _update_cached_all_tokens(self):
    """Rebuild the cached list of all token ids (prompt + output)."""
    assert isinstance(self._prompt_token_ids, array)
    assert isinstance(self._output_token_ids, array)
    self._cached_all_token_ids: List[int] = list(self._prompt_token_ids +
                                                 self._output_token_ids)
@property
def cumulative_logprob(self) -> float:
    # Cumulative log probability of the output so far (read-only view).
    return self._cumulative_logprob
@property
def prompt_token_ids(self) -> Tuple[int, ...]:
    # Immutable tuple view cached in __post_init__.
    return self._prompt_token_ids_tuple
@prompt_token_ids.setter
def prompt_token_ids(self, new_prompt_token_ids) -> None:
self._prompt_token_ids = array('l', new_prompt_token_ids)
self._prompt_token_ids_tuple = tuple(new_prompt_token_ids)
self._update_cached_all_tokens()
raise NotImplementedError
@property
def prompt_token_ids_array(self) -> array:
    """Return the prompt token ids as an ``array``.

    Note that the array uses typecode "l" (C long, platform-dependent
    width), which is not necessarily compatible with torch.long.
    So beware of the usage.
    """
    return self._prompt_token_ids
@property
@@ -166,18 +205,26 @@ class SequenceData:
return tuple(self._output_token_ids)
@output_token_ids.setter
def output_token_ids(self, new_output_token_ids) -> None:
self._output_token_ids = array('l', new_output_token_ids)
def output_token_ids(self, new_output_token_ids: List[int]) -> None:
self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
new_output_token_ids)
self._update_cached_all_tokens()
@property
def output_token_ids_array(self) -> array:
    """Return the output token ids as an ``array``.

    Note that the array uses typecode "l" (C long, platform-dependent
    width), which is not necessarily compatible with torch.long.
    So beware of the usage.
    """
    assert isinstance(self._output_token_ids, array)
    return self._output_token_ids
def append_token_id(self, token_id: int, logprob: float) -> None:
self._output_token_ids.append(token_id)
self._new_appended_tokens.append(token_id)
self._cached_all_token_ids.append(token_id)
self.cumulative_logprob += logprob
self._cumulative_logprob += logprob
def get_len(self) -> int:
    """Return the total sequence length (prompt + output tokens)."""
    return len(self._output_token_ids) + len(self._prompt_token_ids)
@@ -222,6 +269,7 @@ class SequenceData:
"""
self._num_computed_tokens = 0
self._stage = SequenceStage.PREFILL
self._new_appended_tokens = []
def get_num_uncomputed_tokens(self) -> int:
"""Return the number of prefill tokens that are not computed."""
@@ -241,6 +289,21 @@ class SequenceData:
def get_output_token_ids(self) -> Tuple[int, ...]:
return self.output_token_ids
def get_delta_and_reset(self) -> SequenceDataDelta:
    """Snapshot state changed since the last call, then reset the delta.

    Returns:
        A SequenceDataDelta with the tokens appended since the previous
        call plus the current cumulative logprob, computed-token count,
        and stage.
    """
    delta = SequenceDataDelta(self._new_appended_tokens,
                              self._cumulative_logprob,
                              self.get_num_computed_tokens(), self.stage)
    # Reset delta state so the next call only carries new tokens.
    self._new_appended_tokens = []
    return delta
def apply_delta(self, delta: SequenceDataDelta):
    """Apply a delta produced by ``get_delta_and_reset`` to this data.

    Overwrites the scalar state and extends the output token ids and the
    cached all-token list with the newly appended tokens.
    """
    self._num_computed_tokens = delta.new_num_computed_tokens
    self._cumulative_logprob = delta.new_cumulative_logprob
    self._stage = delta.new_stage
    self._output_token_ids.extend(delta.new_output_token_ids)
    self._cached_all_token_ids.extend(delta.new_output_token_ids)
@property
def stage(self) -> SequenceStage:
    # Current SequenceStage of this sequence (initially PREFILL).
    return self._stage
@@ -248,8 +311,9 @@ class SequenceData:
def __repr__(self) -> str:
return (f"SequenceData("
f"prompt_token_ids={self._prompt_token_ids}, "
f"output_token_ids={self._output_token_ids}, "
f"cumulative_logprob={self.cumulative_logprob})")
f"output_token_ids={self.output_token_ids}, "
f"cumulative_logprob={self.cumulative_logprob}, "
f"get_num_computed_tokens={self.get_num_computed_tokens()}")
class Sequence:
@@ -325,7 +389,8 @@ class Sequence:
f"invalid input {inputs}; did you forget the "
"encoder input prompt fields?")
self.data = SequenceData(self.prompt_token_ids)
self.data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, self.prompt_token_ids))
self.output_logprobs: SampleLogprobs = []
self.output_text = ""
@@ -490,8 +555,8 @@ class Sequence:
f"num_blocks={self.n_blocks}, ")
@dataclass
class SequenceGroupState:
class SequenceGroupState(msgspec.Struct,
omit_defaults=True): # type: ignore[call-arg]
"""Mutable state tied to a specific sequence group"""
# for multi-step decoding
@@ -647,14 +712,19 @@ class SequenceGroup:
if self.sampling_params and self.sampling_params.use_beam_search:
# For beam search, maximally there will always be `best_of` beam
# candidates running in the future.
return self.sampling_params.best_of
best_of = self.sampling_params.best_of
assert isinstance(best_of, int)
return best_of
else:
if (self.sampling_params
and self.sampling_params.best_of > self.num_seqs()):
# At prompt stage, the sequence group is not yet filled up
# and only have one sequence running. However, in the
# generation stage, we will have `best_of` sequences running.
return self.sampling_params.best_of
if self.sampling_params:
best_of = self.sampling_params.best_of
assert isinstance(best_of, int)
if best_of > self.num_seqs():
# At prompt stage, the sequence group is not yet filled up
# and only have one sequence running. However, in the
# generation stage, we will have `best_of` sequences
# running.
return best_of
# At sampling stages, return the number of actual sequences
# that are not finished yet.
return self.num_unfinished_seqs()
@@ -757,7 +827,32 @@ class SequenceGroup:
f"num_seqs={len(self.seqs)})")
class SequenceGroupMetadata:
class SequenceGroupMetadataDelta(
        msgspec.Struct,
        tag=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """Delta of SequenceGroupMetadata.

    After sending the first SequenceGroupMetadata, vLLM scheduler
    only sends delta to reduce the data payload size.
    """
    # Per-sequence-id deltas to apply to the corresponding SequenceData.
    seq_data_delta: Dict[int, SequenceDataDelta]
    # Must match the request_id of the metadata this delta is applied to.
    request_id: str
    block_tables: Dict[int, List[int]]
    is_prompt: bool
    do_sample: bool = True
    token_chunk_size: Optional[int] = None
    computed_block_nums: Optional[List[int]] = None
    state: Optional[SequenceGroupState] = msgspec.field(
        default_factory=lambda: SequenceGroupState())
class SequenceGroupMetadata(
msgspec.Struct,
tag=True, # type: ignore[call-arg]
array_like=True, # type: ignore[call-arg]
omit_defaults=True): # type: ignore[call-arg]
"""Metadata for a sequence group. Used to create `AttentionMetadata`.
Args:
@@ -789,52 +884,39 @@ class SequenceGroupMetadata:
prompt_adapter_request: Prompt Adapter request.
"""
def __init__(
self,
request_id: str,
is_prompt: bool,
seq_data: Dict[int, SequenceData],
sampling_params: SamplingParams,
block_tables: Dict[int, List[int]],
do_sample: bool = True,
pooling_params: Optional[PoolingParams] = None,
token_chunk_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
computed_block_nums: Optional[List[int]] = None,
state: Optional[SequenceGroupState] = None,
multi_modal_data: Optional["MultiModalDataDict"] = None,
encoder_seq_data: Optional[SequenceData] = None,
cross_block_table: Optional[List[int]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
self.request_id = request_id
self.is_prompt = is_prompt
self.seq_data = seq_data
self.sampling_params = sampling_params
self.block_tables = block_tables
self.pooling_params = pooling_params
self.lora_request = lora_request
self.prompt_adapter_request = prompt_adapter_request
self.computed_block_nums = computed_block_nums
self.multi_modal_data = multi_modal_data
self.state = SequenceGroupState() if state is None else state
self.encoder_seq_data = encoder_seq_data
self.cross_block_table = cross_block_table
self._token_chunk_size = token_chunk_size
self.do_sample = do_sample
request_id: str
is_prompt: bool
seq_data: Dict[int, SequenceData]
sampling_params: SamplingParams
block_tables: Dict[int, List[int]]
do_sample: bool = True
pooling_params: Optional[PoolingParams] = None
lora_request: Optional[LoRARequest] = None
computed_block_nums: Optional[List[int]] = None
state: Optional[SequenceGroupState] = msgspec.field(
default_factory=lambda: SequenceGroupState())
# "MultiModalDataDict" types. We have to use Any due to msgspec
# doesn't allow to have union of 2 different dicts.
multi_modal_data: Optional[Any] = None
encoder_seq_data: Optional[SequenceData] = None
cross_block_table: Optional[List[int]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
token_chunk_size: Optional[int] = None
# The number of speculative tokens adopted in this request.
# None means speculative decoding is not used.
# Zero means speculative decoding is disabled for some reasons.
# TODO: We should maintain this states out of the sequence group.
self.num_speculative_tokens = None
### Stateful fields that are lazily defined. ###
# The number of speculative tokens adopted in this request.
# None means speculative decoding is not used.
# Zero means speculative decoding is disabled for some reasons.
# TODO: We should maintain this states out of the sequence group.
num_speculative_tokens: Optional[int] = None
if seq_data is not None and self._token_chunk_size is None:
if is_prompt:
self._token_chunk_size = next(iter(
seq_data.values())).get_len()
def __post_init__(self):
if self.seq_data is not None and self.token_chunk_size is None:
if self.is_prompt:
self.token_chunk_size = next(iter(
self.seq_data.values())).get_len()
else:
self._token_chunk_size = 1
self.token_chunk_size = 1
@property
def lora_int_id(self) -> int:
@@ -850,18 +932,26 @@ class SequenceGroupMetadata:
return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \
if self.prompt_adapter_request else 0
@property
def token_chunk_size(self) -> int:
"""Return the number of tokens to be processed (chunk size)."""
assert self._token_chunk_size is not None
return self._token_chunk_size
def apply_delta(self,
                sequence_group_metadata_delta: SequenceGroupMetadataDelta):
    """Apply a scheduler-sent delta to this metadata in place.

    Args:
        sequence_group_metadata_delta: Delta for the same request; its
            ``request_id`` must equal this metadata's ``request_id``.
    """
    # Validate BEFORE mutating: the original asserted after updating
    # seq_data, so a mismatched delta could leave this object partially
    # updated.
    assert self.request_id == sequence_group_metadata_delta.request_id
    # Use `seq_id` rather than `id` to avoid shadowing the builtin.
    for seq_id, delta in (
            sequence_group_metadata_delta.seq_data_delta.items()):
        self.seq_data[seq_id].apply_delta(delta)
    self.block_tables = sequence_group_metadata_delta.block_tables
    self.token_chunk_size = sequence_group_metadata_delta.token_chunk_size
    self.do_sample = sequence_group_metadata_delta.do_sample
    self.is_prompt = sequence_group_metadata_delta.is_prompt
def finish_step(self) -> None:
    """Advance the multi-step decoding state by one step."""
    assert self.state is not None
    # Must not advance past the configured number of steps.
    assert self.state.current_step < self.state.num_steps
    self.state.current_step += 1
class SequenceOutput:
class SequenceOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""The model output associated with a sequence.
Args:
@@ -871,16 +961,9 @@ class SequenceOutput:
logprobs: The logprobs of the output token.
(Token id -> logP(x_i+1 | x_0, ..., x_i))
"""
def __init__(
self,
parent_seq_id: int,
output_token: int,
logprobs: Dict[int, Logprob],
) -> None:
self.parent_seq_id = parent_seq_id
self.output_token = output_token
self.logprobs = logprobs
parent_seq_id: int
output_token: int
logprobs: Dict[int, Logprob]
def __repr__(self) -> str:
return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
@@ -908,17 +991,15 @@ class SequenceGroupOutput(ABC):
pass
class CompletionSequenceGroupOutput(SequenceGroupOutput):
class CompletionSequenceGroupOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
__metaclass__ = SequenceGroupOutput
"""The model output associated with a completion sequence group."""
def __init__(
self,
samples: List[SequenceOutput],
prompt_logprobs: Optional[PromptLogprobs],
) -> None:
self.samples = samples
# Prompt logprob for each prompt query token.
self.prompt_logprobs = prompt_logprobs
samples: List[SequenceOutput]
# Prompt logprob for each prompt query token.
prompt_logprobs: Optional[PromptLogprobs]
def __repr__(self) -> str:
return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
@@ -931,14 +1012,14 @@ class CompletionSequenceGroupOutput(SequenceGroupOutput):
and self.prompt_logprobs == other.prompt_logprobs)
class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
class EmbeddingSequenceGroupOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True, # type: ignore[call-arg]
):
"""The model output associated with an embedding sequence group."""
def __init__(
self,
embeddings: List[float],
) -> None:
self.embeddings = embeddings
__metaclass__ = SequenceGroupOutput
embeddings: List[int]
def __repr__(self) -> str:
return (f"EmbeddingSequenceGroupOutput("
@@ -950,8 +1031,10 @@ class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
return self.embeddings == other.embeddings
@dataclass
class IntermediateTensors:
class IntermediateTensors(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""For all pipeline stages except the last, we need to return the hidden
states and residuals to be sent to the next stage. This data structure
contains the hidden states and residuals for a request.
@@ -978,8 +1061,10 @@ class IntermediateTensors:
return f"IntermediateTensors(tensors={self.tensors})"
@dataclass
class SamplerOutput:
class SamplerOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""For each sequence group, we generate a list of SequenceOutput object,
each of which contains one possible candidate for the next token.
@@ -1000,7 +1085,7 @@ class SamplerOutput:
sampled_token_ids_numpy: Optional[numpy.ndarray] = None
# Spec decode metrics populated by workers.
spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
# Optional last hidden states from the model.
hidden_states: Optional[torch.Tensor] = None
@@ -1039,12 +1124,14 @@ class SamplerOutput:
f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
@dataclass
class PoolerOutput:
class PoolerOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""The output from a pooling operation in the embedding model."""
outputs: List[EmbeddingSequenceGroupOutput]
spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
def __getitem__(self, idx: int):
return self.outputs[idx]
@@ -1083,7 +1170,8 @@ def get_all_seq_ids_and_request_ids(
return seq_ids, request_id_seq_ids_mapping
class HiddenStates:
class HiddenStates(msgspec.Struct, array_like=True,
omit_defaults=True): # type: ignore[call-arg]
"""Hidden states corresponding to in-progress sequences.
Used in speculative decoding to pass hidden states from
the target model to the proposer model in the subsequent step.
@@ -1091,42 +1179,53 @@ class HiddenStates:
seq_ids are the sequence ids of each entry of the batch
dimension of the hidden_states tensor"""
def __init__(self, seq_group_metadata_list: List[SequenceGroupMetadata],
hidden_states: torch.Tensor):
assert len(seq_group_metadata_list) == len(hidden_states)
self.seq_ids: List[int] = get_all_seq_ids(seq_group_metadata_list)
self.hidden_states: torch.Tensor = hidden_states
seq_group_metadata_list: List[SequenceGroupMetadata]
hidden_states: torch.Tensor
_seq_ids: List[int] = msgspec.field(default_factory=list)
def __post_init__(self):
    """Derive the seq id list and check it matches the batch dimension."""
    self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)
    # One hidden-state row per sequence group metadata entry.
    assert len(self.seq_group_metadata_list) == len(self.hidden_states)
@property
def seq_ids(self) -> List[int]:
    # Sequence ids for each entry of the batch dimension of hidden_states.
    return self._seq_ids
def update(self, seq_group_metadata_list: List[SequenceGroupMetadata],
hidden_states: torch.Tensor) -> None:
"""Update hidden states from target model invocation."""
assert len(seq_group_metadata_list) == len(hidden_states)
self.seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
self.hidden_states = torch.cat([self.hidden_states, hidden_states])
def prune(self,
seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
"""Prune to provided list of sequence ids."""
seq_ids = get_all_seq_ids(seq_group_metadata_list)
if seq_ids != self.seq_ids:
if seq_ids != self._seq_ids:
# Batch contents changed - prune removed sequences.
index = [self.seq_ids.index(seq_id) for seq_id in seq_ids]
index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
self.hidden_states = self.hidden_states[index]
self.seq_ids = seq_ids
self._seq_ids = seq_ids
@dataclass
class ExecuteModelRequest:
class ExecuteModelRequest(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True): # type: ignore[call-arg]
"""The model execution request, containing CPU metadata only. The LLM
engine should create an instance of this class for each request batch."""
# The sequence group metadata list.
seq_group_metadata_list: List[SequenceGroupMetadata]
seq_group_metadata_list: List[Union[SequenceGroupMetadata,
SequenceGroupMetadataDelta]]
# Blocks to swap in. List of CPU -> GPU block number.
blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list)
blocks_to_swap_in: List[Tuple[int,
int]] = msgspec.field(default_factory=list)
# Blocks to swap out. List of GPU -> CPU block number.
blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
blocks_to_swap_out: List[Tuple[int,
int]] = msgspec.field(default_factory=list)
# Blocks to copy. Source to dest block.
blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
blocks_to_copy: List[Tuple[int, int]] = msgspec.field(default_factory=list)
# Virtual engine ID for pipeline parallel.
virtual_engine: int = 0
# The number of slots for lookahead decoding.
@@ -1138,7 +1237,7 @@ class ExecuteModelRequest:
# The number of forward steps to run.
num_steps: int = 1
# Finished request ids since last step.
finished_requests_ids: List[str] = field(default_factory=list)
finished_requests_ids: List[str] = msgspec.field(default_factory=list)
# The last sampled token ids for multi step decoding.
last_sampled_token_ids: Optional[torch.Tensor] = None
@@ -1148,6 +1247,7 @@ class ExecuteModelRequest:
# steps
assert len(self.seq_group_metadata_list) > 0
first_seq_group = self.seq_group_metadata_list[0]
assert first_seq_group.state is not None
return first_seq_group.state.current_step == 0
@property
@@ -1156,6 +1256,7 @@ class ExecuteModelRequest:
# steps
assert len(self.seq_group_metadata_list) > 0
first_seq_group = self.seq_group_metadata_list[0]
assert first_seq_group.state is not None
num_steps = first_seq_group.state.num_steps
current_step = first_seq_group.state.current_step
return num_steps - current_step == 1
@@ -1165,10 +1266,13 @@ class ExecuteModelRequest:
# TODO(will) make this be able to handle batches with variable number of
# steps
assert len(self.seq_group_metadata_list) > 0
return self.seq_group_metadata_list[0].state.current_step
state = self.seq_group_metadata_list[0].state
assert state is not None
return state.current_step
def clone(
self, seq_group_metadata_list: List[SequenceGroupMetadata]
self, seq_group_metadata_list: List[Union[SequenceGroupMetadata,
SequenceGroupMetadataDelta]]
) -> "ExecuteModelRequest":
"""Clone the request with a new sequence group metadata list."""
return ExecuteModelRequest(