Update deprecated Python 3.8 typing (#13971)
This commit is contained in:
132
vllm/sequence.py
132
vllm/sequence.py
@@ -5,11 +5,11 @@ import enum
|
||||
from abc import ABC, abstractmethod
|
||||
from array import array
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Sequence as GenericSequence
|
||||
from dataclasses import dataclass, field
|
||||
from functools import reduce
|
||||
from typing import Any, Callable, DefaultDict, Dict, List, Mapping, Optional
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Set, Tuple, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import msgspec
|
||||
import torch
|
||||
@@ -50,9 +50,9 @@ class Logprob:
|
||||
|
||||
# {token_id -> logprob} for each sequence group. None if the corresponding
|
||||
# sequence group doesn't require prompt logprob.
|
||||
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
|
||||
PromptLogprobs = list[Optional[dict[int, Logprob]]]
|
||||
# {token_id -> logprob} for each sequence group.
|
||||
SampleLogprobs = List[Dict[int, Logprob]]
|
||||
SampleLogprobs = list[dict[int, Logprob]]
|
||||
|
||||
|
||||
class SequenceStatus(enum.IntEnum):
|
||||
@@ -129,7 +129,7 @@ class SequenceDataDelta(
|
||||
omit_defaults=True): # type: ignore[call-arg]
|
||||
"""Delta SequenceData to send to workers per step."""
|
||||
# New output token ids to be appended to the existing SequenceData.
|
||||
new_output_token_ids: List[int]
|
||||
new_output_token_ids: list[int]
|
||||
# Overwriting existing `cumulative_logprob`
|
||||
new_cumulative_logprob: float
|
||||
# Overwriting existing `num_computed_tokens`.
|
||||
@@ -152,7 +152,7 @@ class SequenceData(msgspec.Struct,
|
||||
output_token_ids: The token IDs of the output.
|
||||
cumulative_logprob: The cumulative log probability of the output.
|
||||
"""
|
||||
# NOTE: we cannot use Union[List, array] because msgspec cannot support
|
||||
# NOTE: we cannot use Union[list, array] because msgspec cannot support
|
||||
# union of 2 list types.
|
||||
_prompt_token_ids: array
|
||||
_output_token_ids: array = msgspec.field(
|
||||
@@ -160,25 +160,25 @@ class SequenceData(msgspec.Struct,
|
||||
|
||||
### The below fields should not be passed as an argument ###
|
||||
_cumulative_logprob: float = 0.0
|
||||
_prompt_token_ids_tuple: Tuple[int,
|
||||
_prompt_token_ids_tuple: tuple[int,
|
||||
...] = msgspec.field(default_factory=tuple)
|
||||
# The number of tokens that are computed (that run against the model).
|
||||
_num_computed_tokens: int = 0
|
||||
# The number of tokens with prefix cache hit.
|
||||
_num_cached_tokens: int = 0
|
||||
_stage: SequenceStage = SequenceStage.PREFILL
|
||||
_cached_all_token_ids: List[int] = msgspec.field(default_factory=list)
|
||||
_cached_all_token_ids: list[int] = msgspec.field(default_factory=list)
|
||||
|
||||
# It is used to get delta input. It is reset when `get_delta_and_reset`
|
||||
# is called.
|
||||
_new_appended_tokens: List[int] = msgspec.field(default_factory=list)
|
||||
_new_appended_tokens: list[int] = msgspec.field(default_factory=list)
|
||||
|
||||
# It is used to compute mrope_position_ids.
|
||||
_mrope_position_delta: Optional[int] = None
|
||||
|
||||
@staticmethod
|
||||
def from_prompt_token_counts(
|
||||
*token_counts: Tuple[int, int]) -> "SequenceData":
|
||||
*token_counts: tuple[int, int]) -> "SequenceData":
|
||||
"""
|
||||
Construct a :class:`SequenceData` instance by concatenating
|
||||
prompt token sequences.
|
||||
@@ -220,14 +220,14 @@ class SequenceData(msgspec.Struct,
|
||||
def __post_init__(self) -> None:
|
||||
assert self._prompt_token_ids.typecode == "l"
|
||||
assert self._output_token_ids.typecode == "l"
|
||||
self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(
|
||||
self._prompt_token_ids_tuple: tuple[int, ...] = tuple(
|
||||
self._prompt_token_ids)
|
||||
self._update_cached_all_tokens()
|
||||
|
||||
def _update_cached_all_tokens(self):
|
||||
assert isinstance(self._prompt_token_ids, array)
|
||||
assert isinstance(self._output_token_ids, array)
|
||||
self._cached_all_token_ids: List[int] = list(self._prompt_token_ids +
|
||||
self._cached_all_token_ids: list[int] = list(self._prompt_token_ids +
|
||||
self._output_token_ids)
|
||||
|
||||
@property
|
||||
@@ -235,7 +235,7 @@ class SequenceData(msgspec.Struct,
|
||||
return self._cumulative_logprob
|
||||
|
||||
@property
|
||||
def prompt_token_ids(self) -> Tuple[int, ...]:
|
||||
def prompt_token_ids(self) -> tuple[int, ...]:
|
||||
return self._prompt_token_ids_tuple
|
||||
|
||||
@prompt_token_ids.setter
|
||||
@@ -252,7 +252,7 @@ class SequenceData(msgspec.Struct,
|
||||
return self._prompt_token_ids
|
||||
|
||||
@property
|
||||
def output_token_ids(self) -> Tuple[int, ...]:
|
||||
def output_token_ids(self) -> tuple[int, ...]:
|
||||
return tuple(self._output_token_ids)
|
||||
|
||||
@output_token_ids.setter
|
||||
@@ -295,12 +295,12 @@ class SequenceData(msgspec.Struct,
|
||||
def get_output_len(self) -> int:
|
||||
return len(self._output_token_ids)
|
||||
|
||||
def get_token_ids(self) -> List[int]:
|
||||
def get_token_ids(self) -> list[int]:
|
||||
return self._cached_all_token_ids
|
||||
|
||||
def get_prefix_token_ids(
|
||||
self, num_tokens: int
|
||||
) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
|
||||
) -> tuple[tuple[int, ...], Optional[tuple[int, ...]]]:
|
||||
"""Get prefix tokens, and make the return value hashable"""
|
||||
prompt_length = self.get_prompt_len()
|
||||
if num_tokens > prompt_length:
|
||||
@@ -351,10 +351,10 @@ class SequenceData(msgspec.Struct,
|
||||
return self._prompt_token_ids[-1]
|
||||
return self._output_token_ids[-1]
|
||||
|
||||
def get_prompt_token_ids(self) -> Tuple[int, ...]:
|
||||
def get_prompt_token_ids(self) -> tuple[int, ...]:
|
||||
return self.prompt_token_ids
|
||||
|
||||
def get_output_token_ids(self) -> Tuple[int, ...]:
|
||||
def get_output_token_ids(self) -> tuple[int, ...]:
|
||||
return self.output_token_ids
|
||||
|
||||
def get_delta_and_reset(self) -> SequenceDataDelta:
|
||||
@@ -432,7 +432,7 @@ class Sequence:
|
||||
self.prefix_offset = 0
|
||||
self.read_offset = 0
|
||||
# Input + output tokens
|
||||
self.tokens: Optional[List[str]] = None
|
||||
self.tokens: Optional[list[str]] = None
|
||||
|
||||
@property
|
||||
def n_blocks(self) -> int:
|
||||
@@ -443,7 +443,7 @@ class Sequence:
|
||||
return self.inputs.prompt
|
||||
|
||||
@property
|
||||
def prompt_token_ids(self) -> List[int]:
|
||||
def prompt_token_ids(self) -> list[int]:
|
||||
return self.inputs.prompt_token_ids
|
||||
|
||||
@property
|
||||
@@ -451,7 +451,7 @@ class Sequence:
|
||||
return self.inputs.prompt_embeds
|
||||
|
||||
@property
|
||||
def token_type_ids(self) -> List[int]:
|
||||
def token_type_ids(self) -> list[int]:
|
||||
return self.inputs.token_type_ids
|
||||
|
||||
@property
|
||||
@@ -463,7 +463,7 @@ class Sequence:
|
||||
return self.inputs.multi_modal_placeholders
|
||||
|
||||
@property
|
||||
def mm_processor_kwargs(self) -> Dict[str, Any]:
|
||||
def mm_processor_kwargs(self) -> dict[str, Any]:
|
||||
return self.inputs.mm_processor_kwargs
|
||||
|
||||
@property
|
||||
@@ -548,7 +548,7 @@ class Sequence:
|
||||
"""Reset the sequence states for recomputation."""
|
||||
self.data.reset_state_for_recompute()
|
||||
|
||||
def append_token_id(self, token_id: int, logprobs: Dict[int,
|
||||
def append_token_id(self, token_id: int, logprobs: dict[int,
|
||||
Logprob]) -> None:
|
||||
assert token_id in logprobs
|
||||
self.output_logprobs.append(logprobs)
|
||||
@@ -563,16 +563,16 @@ class Sequence:
|
||||
def get_output_len(self) -> int:
|
||||
return self.data.get_output_len()
|
||||
|
||||
def get_token_ids(self) -> List[int]:
|
||||
def get_token_ids(self) -> list[int]:
|
||||
return self.data.get_token_ids()
|
||||
|
||||
def get_prompt_token_ids(self) -> Tuple[int, ...]:
|
||||
def get_prompt_token_ids(self) -> tuple[int, ...]:
|
||||
return self.data.get_prompt_token_ids()
|
||||
|
||||
def get_last_token_id(self) -> int:
|
||||
return self.data.get_last_token_id()
|
||||
|
||||
def get_output_token_ids(self) -> Tuple[int, ...]:
|
||||
def get_output_token_ids(self) -> tuple[int, ...]:
|
||||
return self.data.get_output_token_ids()
|
||||
|
||||
def get_cumulative_logprob(self) -> float:
|
||||
@@ -644,7 +644,7 @@ class SequenceGroup:
|
||||
def __init__(
|
||||
self,
|
||||
request_id: str,
|
||||
seqs: List[Sequence],
|
||||
seqs: list[Sequence],
|
||||
arrival_time: float,
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
@@ -686,7 +686,7 @@ class SequenceGroup:
|
||||
return self.first_seq.prompt
|
||||
|
||||
@property
|
||||
def prompt_token_ids(self) -> List[int]:
|
||||
def prompt_token_ids(self) -> list[int]:
|
||||
return self.first_seq.prompt_token_ids
|
||||
|
||||
@property
|
||||
@@ -698,7 +698,7 @@ class SequenceGroup:
|
||||
if self.encoder_seq is not None else None)
|
||||
|
||||
@property
|
||||
def encoder_prompt_token_ids(self) -> Optional[List[int]]:
|
||||
def encoder_prompt_token_ids(self) -> Optional[list[int]]:
|
||||
# There are either 0 or 1 encoder sequences
|
||||
# If one is present, its prompt token ids are
|
||||
# distinct from the decoder's.
|
||||
@@ -706,7 +706,7 @@ class SequenceGroup:
|
||||
if self.encoder_seq is not None else None)
|
||||
|
||||
@property
|
||||
def token_type_ids(self) -> Optional[List[int]]:
|
||||
def token_type_ids(self) -> Optional[list[int]]:
|
||||
return self.first_seq.token_type_ids
|
||||
|
||||
@property
|
||||
@@ -726,7 +726,7 @@ class SequenceGroup:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def mm_processor_kwargs(self) -> Dict[str, Any]:
|
||||
def mm_processor_kwargs(self) -> dict[str, Any]:
|
||||
if self.first_seq.multi_modal_data:
|
||||
return self.first_seq.mm_processor_kwargs
|
||||
elif self.encoder_seq is not None:
|
||||
@@ -823,7 +823,7 @@ class SequenceGroup:
|
||||
def get_seqs(
|
||||
self,
|
||||
status: Optional[SequenceStatus] = None,
|
||||
) -> List[Sequence]:
|
||||
) -> list[Sequence]:
|
||||
if status is None:
|
||||
return self.seqs
|
||||
|
||||
@@ -838,7 +838,7 @@ class SequenceGroup:
|
||||
def get_encoder_seq(self) -> Optional[Sequence]:
|
||||
return self.encoder_seq
|
||||
|
||||
def get_finished_seqs(self) -> List[Sequence]:
|
||||
def get_finished_seqs(self) -> list[Sequence]:
|
||||
if self.is_single_seq:
|
||||
return self.seqs if self.first_seq.is_finished() else []
|
||||
|
||||
@@ -897,13 +897,13 @@ class SequenceGroupMetadataDelta(
|
||||
After sending the first SequenceGroupMetadata, vLLM scheduler
|
||||
only sends delta to reduce the data payload size.
|
||||
"""
|
||||
seq_data_delta: Dict[int, SequenceDataDelta]
|
||||
seq_data_delta: dict[int, SequenceDataDelta]
|
||||
request_id: str
|
||||
block_tables: Dict[int, List[int]]
|
||||
block_tables: dict[int, list[int]]
|
||||
is_prompt: bool
|
||||
do_sample: bool = True
|
||||
token_chunk_size: Optional[int] = None
|
||||
computed_block_nums: Optional[List[int]] = None
|
||||
computed_block_nums: Optional[list[int]] = None
|
||||
state: Optional[SequenceGroupState] = msgspec.field(
|
||||
default_factory=lambda: SequenceGroupState())
|
||||
|
||||
@@ -947,23 +947,23 @@ class SequenceGroupMetadata(
|
||||
|
||||
request_id: str
|
||||
is_prompt: bool
|
||||
seq_data: Dict[int, SequenceData]
|
||||
seq_data: dict[int, SequenceData]
|
||||
sampling_params: Optional[SamplingParams]
|
||||
block_tables: Dict[int, List[int]]
|
||||
block_tables: dict[int, list[int]]
|
||||
do_sample: bool = True
|
||||
pooling_params: Optional[PoolingParams] = None
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
computed_block_nums: Optional[List[int]] = None
|
||||
computed_block_nums: Optional[list[int]] = None
|
||||
state: Optional[SequenceGroupState] = msgspec.field(
|
||||
default_factory=lambda: SequenceGroupState())
|
||||
# "MultiModalDataDict" types. We have to use Any due to msgspec
|
||||
# not allowing a union of 2 different dicts.
|
||||
token_type_ids: Optional[List[int]] = None
|
||||
token_type_ids: Optional[list[int]] = None
|
||||
multi_modal_data: Optional[Any] = None
|
||||
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
|
||||
mm_processor_kwargs: Optional[Dict[str, Any]] = None
|
||||
mm_processor_kwargs: Optional[dict[str, Any]] = None
|
||||
encoder_seq_data: Optional[SequenceData] = None
|
||||
cross_block_table: Optional[List[int]] = None
|
||||
cross_block_table: Optional[list[int]] = None
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None
|
||||
token_chunk_size: Optional[int] = None
|
||||
|
||||
@@ -1042,7 +1042,7 @@ class SequenceOutput(
|
||||
"""
|
||||
parent_seq_id: int
|
||||
output_token: int
|
||||
logprobs: Dict[int, Logprob]
|
||||
logprobs: dict[int, Logprob]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
|
||||
@@ -1076,7 +1076,7 @@ class CompletionSequenceGroupOutput(
|
||||
array_like=True): # type: ignore[call-arg]
|
||||
"""The model output associated with a completion sequence group."""
|
||||
__metaclass__ = SequenceGroupOutput
|
||||
samples: List[SequenceOutput]
|
||||
samples: list[SequenceOutput]
|
||||
# Prompt logprob for each prompt query token.
|
||||
prompt_logprobs: Optional[PromptLogprobs]
|
||||
|
||||
@@ -1119,7 +1119,7 @@ class IntermediateTensors:
|
||||
contains the hidden states and residuals for a request.
|
||||
"""
|
||||
|
||||
tensors: Dict[str, torch.Tensor]
|
||||
tensors: dict[str, torch.Tensor]
|
||||
|
||||
def __init__(self, tensors):
|
||||
# manually define this function, so that
|
||||
@@ -1155,7 +1155,7 @@ class PoolerOutput(
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
array_like=True): # type: ignore[call-arg]
|
||||
"""The output from a pooling operation in the pooling model."""
|
||||
outputs: List[PoolingSequenceGroupOutput]
|
||||
outputs: list[PoolingSequenceGroupOutput]
|
||||
|
||||
def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
|
||||
return self.outputs[idx]
|
||||
@@ -1172,7 +1172,7 @@ class PoolerOutput(
|
||||
|
||||
|
||||
def get_all_seq_ids(
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
|
||||
seq_group_metadata_list: list[SequenceGroupMetadata]) -> list[int]:
|
||||
"""Given a list of SequenceGroupMetadata, create a list of all
|
||||
sequence ids.
|
||||
"""
|
||||
@@ -1180,13 +1180,13 @@ def get_all_seq_ids(
|
||||
|
||||
|
||||
def get_all_seq_ids_and_request_ids(
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata]
|
||||
) -> Tuple[List[int], Dict[str, Set[int]]]:
|
||||
seq_group_metadata_list: list[SequenceGroupMetadata]
|
||||
) -> tuple[list[int], dict[str, set[int]]]:
|
||||
"""Given a list of SequenceGroupMetadata, create a list of all
|
||||
sequence ids.
|
||||
"""
|
||||
seq_ids: List[int] = []
|
||||
request_id_seq_ids_mapping: DefaultDict[str, Set[int]] = defaultdict(set)
|
||||
seq_ids: list[int] = []
|
||||
request_id_seq_ids_mapping: defaultdict[str, set[int]] = defaultdict(set)
|
||||
for sg in seq_group_metadata_list:
|
||||
for seq_id in sg.seq_data:
|
||||
seq_ids.append(seq_id)
|
||||
@@ -1206,14 +1206,14 @@ class HiddenStates(msgspec.Struct, array_like=True,
|
||||
# all tokens, whereas for decode step, it is used for last accepted tokens.
|
||||
hidden_states: torch.Tensor
|
||||
# The sequence group metadata list. Only needed for decode step.
|
||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
|
||||
seq_group_metadata_list: Optional[list[SequenceGroupMetadata]] = None
|
||||
# Scorer hidden states of the 2nd last token proposed by the proposer (
|
||||
# irrespective of whether it was accepted or not). Only used for cases when
|
||||
# last proposed token is accepted (i.e., in case of bonus tokens). For the
|
||||
# case of no bonus tokens, these are ignored.
|
||||
second_last_token_hidden_states: Optional[torch.Tensor] = None
|
||||
|
||||
_seq_ids: List[int] = msgspec.field(default_factory=list)
|
||||
_seq_ids: list[int] = msgspec.field(default_factory=list)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.seq_group_metadata_list is not None:
|
||||
@@ -1221,12 +1221,12 @@ class HiddenStates(msgspec.Struct, array_like=True,
|
||||
self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)
|
||||
|
||||
@property
|
||||
def seq_ids(self) -> List[int]:
|
||||
def seq_ids(self) -> list[int]:
|
||||
return self._seq_ids
|
||||
|
||||
def update(self,
|
||||
hidden_states: torch.Tensor,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
seq_group_metadata_list: list[SequenceGroupMetadata],
|
||||
second_last_token_hidden_states: Optional[torch.Tensor] = None):
|
||||
"""Update hidden states from target model invocation. Only used for
|
||||
decode steps"""
|
||||
@@ -1244,7 +1244,7 @@ class HiddenStates(msgspec.Struct, array_like=True,
|
||||
])
|
||||
|
||||
def prune(self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
|
||||
seq_group_metadata_list: list[SequenceGroupMetadata]) -> None:
|
||||
"""Prune to provided list of sequence ids. Only used for decode steps.
|
||||
"""
|
||||
# Currently this prunes all seq_ids not present in
|
||||
@@ -1287,16 +1287,16 @@ class ExecuteModelRequest(
|
||||
"""The model execution request, containing CPU metadata only. The LLM
|
||||
engine should create an instance of this class for each request batch."""
|
||||
# The sequence group metadata list.
|
||||
seq_group_metadata_list: List[Union[SequenceGroupMetadata,
|
||||
seq_group_metadata_list: list[Union[SequenceGroupMetadata,
|
||||
SequenceGroupMetadataDelta]]
|
||||
# Blocks to swap in. List of CPU -> GPU block number.
|
||||
blocks_to_swap_in: List[Tuple[int,
|
||||
blocks_to_swap_in: list[tuple[int,
|
||||
int]] = msgspec.field(default_factory=list)
|
||||
# Blocks to swap out. List of GPU -> CPU block number.
|
||||
blocks_to_swap_out: List[Tuple[int,
|
||||
blocks_to_swap_out: list[tuple[int,
|
||||
int]] = msgspec.field(default_factory=list)
|
||||
# Blocks to copy. Source to dest block.
|
||||
blocks_to_copy: List[Tuple[int, int]] = msgspec.field(default_factory=list)
|
||||
blocks_to_copy: list[tuple[int, int]] = msgspec.field(default_factory=list)
|
||||
# Virtual engine ID for pipeline parallel.
|
||||
virtual_engine: int = 0
|
||||
# The number of slots for lookahead decoding.
|
||||
@@ -1310,7 +1310,7 @@ class ExecuteModelRequest(
|
||||
# The step index for spec model input.
|
||||
spec_step_idx: Optional[int] = None
|
||||
# Finished request ids since last step.
|
||||
finished_requests_ids: List[str] = msgspec.field(default_factory=list)
|
||||
finished_requests_ids: list[str] = msgspec.field(default_factory=list)
|
||||
# The last sampled token ids for multi step decoding.
|
||||
last_sampled_token_ids: Optional[torch.Tensor] = None
|
||||
# Async callback
|
||||
@@ -1344,7 +1344,7 @@ class ExecuteModelRequest(
|
||||
return state.current_step
|
||||
|
||||
def clone(
|
||||
self, seq_group_metadata_list: List[Union[SequenceGroupMetadata,
|
||||
self, seq_group_metadata_list: list[Union[SequenceGroupMetadata,
|
||||
SequenceGroupMetadataDelta]]
|
||||
) -> "ExecuteModelRequest":
|
||||
"""Clone the request with a new sequence group metadata list."""
|
||||
@@ -1371,13 +1371,13 @@ class SequenceGroupBase:
|
||||
assembled_seq_group: Optional[SequenceGroup] = None
|
||||
|
||||
# seq id to a unique index inside this group
|
||||
seq_id_to_index: Dict[str, int] = field(default_factory=dict)
|
||||
seq_id_to_index: dict[str, int] = field(default_factory=dict)
|
||||
|
||||
# seq ids to be finished
|
||||
to_be_finished: Dict[str, SequenceGroup] = field(default_factory=dict)
|
||||
to_be_finished: dict[str, SequenceGroup] = field(default_factory=dict)
|
||||
|
||||
# seq id to finished sequences
|
||||
finished_reqs: Dict[str, SequenceGroup] = field(default_factory=dict)
|
||||
finished_reqs: dict[str, SequenceGroup] = field(default_factory=dict)
|
||||
|
||||
streaming: bool = False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user