Update deprecated Python 3.8 typing (#13971)

This commit is contained in:
Harry Mellor
2025-03-03 01:34:51 +00:00
committed by GitHub
parent bf33700ecd
commit cf069aa8aa
300 changed files with 2294 additions and 2347 deletions

View File

@@ -5,11 +5,11 @@ import enum
from abc import ABC, abstractmethod
from array import array
from collections import defaultdict
from collections.abc import Mapping
from collections.abc import Sequence as GenericSequence
from dataclasses import dataclass, field
from functools import reduce
from typing import Any, Callable, DefaultDict, Dict, List, Mapping, Optional
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Union
from typing import Any, Callable, Optional, Union
import msgspec
import torch
@@ -50,9 +50,9 @@ class Logprob:
# {token_id -> logprob} per each sequence group. None if the corresponding
# sequence group doesn't require prompt logprob.
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
PromptLogprobs = list[Optional[dict[int, Logprob]]]
# {token_id -> logprob} for each sequence group.
SampleLogprobs = List[Dict[int, Logprob]]
SampleLogprobs = list[dict[int, Logprob]]
class SequenceStatus(enum.IntEnum):
@@ -129,7 +129,7 @@ class SequenceDataDelta(
omit_defaults=True): # type: ignore[call-arg]
"""Delta SequenceData to send to workers per step."""
# A new token to be appended to existing SequenceData.
new_output_token_ids: List[int]
new_output_token_ids: list[int]
# Overwriting existing `cumulative_logprob`
new_cumulative_logprob: float
# Overwriting existing `num_computed_tokens`.
@@ -152,7 +152,7 @@ class SequenceData(msgspec.Struct,
output_token_ids: The token IDs of the output.
cumulative_logprob: The cumulative log probability of the output.
"""
# NOTE: we cannot use Union[List, array] because msgspec cannot support
# NOTE: we cannot use Union[list, array] because msgspec cannot support
# union of 2 list types.
_prompt_token_ids: array
_output_token_ids: array = msgspec.field(
@@ -160,25 +160,25 @@ class SequenceData(msgspec.Struct,
### The below fields should not be passed as an argument ###
_cumulative_logprob: float = 0.0
_prompt_token_ids_tuple: Tuple[int,
_prompt_token_ids_tuple: tuple[int,
...] = msgspec.field(default_factory=tuple)
# The number of tokens that are computed (that run against the model).
_num_computed_tokens: int = 0
# The number of tokens with prefix cache hit.
_num_cached_tokens: int = 0
_stage: SequenceStage = SequenceStage.PREFILL
_cached_all_token_ids: List[int] = msgspec.field(default_factory=list)
_cached_all_token_ids: list[int] = msgspec.field(default_factory=list)
# It is used to get delta input. It is reset when `get_delta_and_reset`
# is called.
_new_appended_tokens: List[int] = msgspec.field(default_factory=list)
_new_appended_tokens: list[int] = msgspec.field(default_factory=list)
# It is used to compute mrope_position_ids.
_mrope_position_delta: Optional[int] = None
@staticmethod
def from_prompt_token_counts(
*token_counts: Tuple[int, int]) -> "SequenceData":
*token_counts: tuple[int, int]) -> "SequenceData":
"""
Construct a :class:`SequenceData` instance by concatenating
prompt token sequences.
@@ -220,14 +220,14 @@ class SequenceData(msgspec.Struct,
def __post_init__(self) -> None:
assert self._prompt_token_ids.typecode == "l"
assert self._output_token_ids.typecode == "l"
self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(
self._prompt_token_ids_tuple: tuple[int, ...] = tuple(
self._prompt_token_ids)
self._update_cached_all_tokens()
def _update_cached_all_tokens(self):
assert isinstance(self._prompt_token_ids, array)
assert isinstance(self._output_token_ids, array)
self._cached_all_token_ids: List[int] = list(self._prompt_token_ids +
self._cached_all_token_ids: list[int] = list(self._prompt_token_ids +
self._output_token_ids)
@property
@@ -235,7 +235,7 @@ class SequenceData(msgspec.Struct,
return self._cumulative_logprob
@property
def prompt_token_ids(self) -> Tuple[int, ...]:
def prompt_token_ids(self) -> tuple[int, ...]:
return self._prompt_token_ids_tuple
@prompt_token_ids.setter
@@ -252,7 +252,7 @@ class SequenceData(msgspec.Struct,
return self._prompt_token_ids
@property
def output_token_ids(self) -> Tuple[int, ...]:
def output_token_ids(self) -> tuple[int, ...]:
return tuple(self._output_token_ids)
@output_token_ids.setter
@@ -295,12 +295,12 @@ class SequenceData(msgspec.Struct,
def get_output_len(self) -> int:
return len(self._output_token_ids)
def get_token_ids(self) -> List[int]:
def get_token_ids(self) -> list[int]:
return self._cached_all_token_ids
def get_prefix_token_ids(
self, num_tokens: int
) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
) -> tuple[tuple[int, ...], Optional[tuple[int, ...]]]:
"""Get prefix tokens, and make the return value hashable"""
prompt_length = self.get_prompt_len()
if num_tokens > prompt_length:
@@ -351,10 +351,10 @@ class SequenceData(msgspec.Struct,
return self._prompt_token_ids[-1]
return self._output_token_ids[-1]
def get_prompt_token_ids(self) -> Tuple[int, ...]:
def get_prompt_token_ids(self) -> tuple[int, ...]:
return self.prompt_token_ids
def get_output_token_ids(self) -> Tuple[int, ...]:
def get_output_token_ids(self) -> tuple[int, ...]:
return self.output_token_ids
def get_delta_and_reset(self) -> SequenceDataDelta:
@@ -432,7 +432,7 @@ class Sequence:
self.prefix_offset = 0
self.read_offset = 0
# Input + output tokens
self.tokens: Optional[List[str]] = None
self.tokens: Optional[list[str]] = None
@property
def n_blocks(self) -> int:
@@ -443,7 +443,7 @@ class Sequence:
return self.inputs.prompt
@property
def prompt_token_ids(self) -> List[int]:
def prompt_token_ids(self) -> list[int]:
return self.inputs.prompt_token_ids
@property
@@ -451,7 +451,7 @@ class Sequence:
return self.inputs.prompt_embeds
@property
def token_type_ids(self) -> List[int]:
def token_type_ids(self) -> list[int]:
return self.inputs.token_type_ids
@property
@@ -463,7 +463,7 @@ class Sequence:
return self.inputs.multi_modal_placeholders
@property
def mm_processor_kwargs(self) -> Dict[str, Any]:
def mm_processor_kwargs(self) -> dict[str, Any]:
return self.inputs.mm_processor_kwargs
@property
@@ -548,7 +548,7 @@ class Sequence:
"""Reset the sequence states for recomputation."""
self.data.reset_state_for_recompute()
def append_token_id(self, token_id: int, logprobs: Dict[int,
def append_token_id(self, token_id: int, logprobs: dict[int,
Logprob]) -> None:
assert token_id in logprobs
self.output_logprobs.append(logprobs)
@@ -563,16 +563,16 @@ class Sequence:
def get_output_len(self) -> int:
return self.data.get_output_len()
def get_token_ids(self) -> List[int]:
def get_token_ids(self) -> list[int]:
return self.data.get_token_ids()
def get_prompt_token_ids(self) -> Tuple[int, ...]:
def get_prompt_token_ids(self) -> tuple[int, ...]:
return self.data.get_prompt_token_ids()
def get_last_token_id(self) -> int:
return self.data.get_last_token_id()
def get_output_token_ids(self) -> Tuple[int, ...]:
def get_output_token_ids(self) -> tuple[int, ...]:
return self.data.get_output_token_ids()
def get_cumulative_logprob(self) -> float:
@@ -644,7 +644,7 @@ class SequenceGroup:
def __init__(
self,
request_id: str,
seqs: List[Sequence],
seqs: list[Sequence],
arrival_time: float,
sampling_params: Optional[SamplingParams] = None,
lora_request: Optional[LoRARequest] = None,
@@ -686,7 +686,7 @@ class SequenceGroup:
return self.first_seq.prompt
@property
def prompt_token_ids(self) -> List[int]:
def prompt_token_ids(self) -> list[int]:
return self.first_seq.prompt_token_ids
@property
@@ -698,7 +698,7 @@ class SequenceGroup:
if self.encoder_seq is not None else None)
@property
def encoder_prompt_token_ids(self) -> Optional[List[int]]:
def encoder_prompt_token_ids(self) -> Optional[list[int]]:
# There are either 0 or 1 encoder sequences
# If one is present, its prompt token ids are
# distinct from the decoder's.
@@ -706,7 +706,7 @@ class SequenceGroup:
if self.encoder_seq is not None else None)
@property
def token_type_ids(self) -> Optional[List[int]]:
def token_type_ids(self) -> Optional[list[int]]:
return self.first_seq.token_type_ids
@property
@@ -726,7 +726,7 @@ class SequenceGroup:
return {}
@property
def mm_processor_kwargs(self) -> Dict[str, Any]:
def mm_processor_kwargs(self) -> dict[str, Any]:
if self.first_seq.multi_modal_data:
return self.first_seq.mm_processor_kwargs
elif self.encoder_seq is not None:
@@ -823,7 +823,7 @@ class SequenceGroup:
def get_seqs(
self,
status: Optional[SequenceStatus] = None,
) -> List[Sequence]:
) -> list[Sequence]:
if status is None:
return self.seqs
@@ -838,7 +838,7 @@ class SequenceGroup:
def get_encoder_seq(self) -> Optional[Sequence]:
return self.encoder_seq
def get_finished_seqs(self) -> List[Sequence]:
def get_finished_seqs(self) -> list[Sequence]:
if self.is_single_seq:
return self.seqs if self.first_seq.is_finished() else []
@@ -897,13 +897,13 @@ class SequenceGroupMetadataDelta(
After sending the first SequenceGroupMetadata, vLLM scheduler
only sends delta to reduce the data payload size.
"""
seq_data_delta: Dict[int, SequenceDataDelta]
seq_data_delta: dict[int, SequenceDataDelta]
request_id: str
block_tables: Dict[int, List[int]]
block_tables: dict[int, list[int]]
is_prompt: bool
do_sample: bool = True
token_chunk_size: Optional[int] = None
computed_block_nums: Optional[List[int]] = None
computed_block_nums: Optional[list[int]] = None
state: Optional[SequenceGroupState] = msgspec.field(
default_factory=lambda: SequenceGroupState())
@@ -947,23 +947,23 @@ class SequenceGroupMetadata(
request_id: str
is_prompt: bool
seq_data: Dict[int, SequenceData]
seq_data: dict[int, SequenceData]
sampling_params: Optional[SamplingParams]
block_tables: Dict[int, List[int]]
block_tables: dict[int, list[int]]
do_sample: bool = True
pooling_params: Optional[PoolingParams] = None
lora_request: Optional[LoRARequest] = None
computed_block_nums: Optional[List[int]] = None
computed_block_nums: Optional[list[int]] = None
state: Optional[SequenceGroupState] = msgspec.field(
default_factory=lambda: SequenceGroupState())
# "MultiModalDataDict" types. We have to use Any due to msgspec
# doesn't allow to have union of 2 different dicts.
token_type_ids: Optional[List[int]] = None
token_type_ids: Optional[list[int]] = None
multi_modal_data: Optional[Any] = None
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None
mm_processor_kwargs: Optional[dict[str, Any]] = None
encoder_seq_data: Optional[SequenceData] = None
cross_block_table: Optional[List[int]] = None
cross_block_table: Optional[list[int]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
token_chunk_size: Optional[int] = None
@@ -1042,7 +1042,7 @@ class SequenceOutput(
"""
parent_seq_id: int
output_token: int
logprobs: Dict[int, Logprob]
logprobs: dict[int, Logprob]
def __repr__(self) -> str:
return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
@@ -1076,7 +1076,7 @@ class CompletionSequenceGroupOutput(
array_like=True): # type: ignore[call-arg]
"""The model output associated with a completion sequence group."""
__metaclass__ = SequenceGroupOutput
samples: List[SequenceOutput]
samples: list[SequenceOutput]
# Prompt logprob for each prompt query token.
prompt_logprobs: Optional[PromptLogprobs]
@@ -1119,7 +1119,7 @@ class IntermediateTensors:
contains the hidden states and residuals for a request.
"""
tensors: Dict[str, torch.Tensor]
tensors: dict[str, torch.Tensor]
def __init__(self, tensors):
# manually define this function, so that
@@ -1155,7 +1155,7 @@ class PoolerOutput(
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""The output from a pooling operation in the pooling model."""
outputs: List[PoolingSequenceGroupOutput]
outputs: list[PoolingSequenceGroupOutput]
def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
return self.outputs[idx]
@@ -1172,7 +1172,7 @@ class PoolerOutput(
def get_all_seq_ids(
seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
seq_group_metadata_list: list[SequenceGroupMetadata]) -> list[int]:
"""Given a list of SequenceGroupMetadata, create a list of all
sequence ids.
"""
@@ -1180,13 +1180,13 @@ def get_all_seq_ids(
def get_all_seq_ids_and_request_ids(
seq_group_metadata_list: List[SequenceGroupMetadata]
) -> Tuple[List[int], Dict[str, Set[int]]]:
seq_group_metadata_list: list[SequenceGroupMetadata]
) -> tuple[list[int], dict[str, set[int]]]:
"""Given a list of SequenceGroupMetadata, create a list of all
sequence ids.
"""
seq_ids: List[int] = []
request_id_seq_ids_mapping: DefaultDict[str, Set[int]] = defaultdict(set)
seq_ids: list[int] = []
request_id_seq_ids_mapping: defaultdict[str, set[int]] = defaultdict(set)
for sg in seq_group_metadata_list:
for seq_id in sg.seq_data:
seq_ids.append(seq_id)
@@ -1206,14 +1206,14 @@ class HiddenStates(msgspec.Struct, array_like=True,
# all tokens, whereas for decode step, it use used for last accepted tokens.
hidden_states: torch.Tensor
# The sequence group metadata list. Only needed for decode step.
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
seq_group_metadata_list: Optional[list[SequenceGroupMetadata]] = None
# Scorer hidden states of the 2nd last token proposed by the proposer (
# irrespective of whether it was accepted or not). Only used for cases when
# last proposed token is accepted (i.e., in case of bonus tokens). For the
# case of no bonus tokens, these are ignored.
second_last_token_hidden_states: Optional[torch.Tensor] = None
_seq_ids: List[int] = msgspec.field(default_factory=list)
_seq_ids: list[int] = msgspec.field(default_factory=list)
def __post_init__(self):
if self.seq_group_metadata_list is not None:
@@ -1221,12 +1221,12 @@ class HiddenStates(msgspec.Struct, array_like=True,
self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)
@property
def seq_ids(self) -> List[int]:
def seq_ids(self) -> list[int]:
return self._seq_ids
def update(self,
hidden_states: torch.Tensor,
seq_group_metadata_list: List[SequenceGroupMetadata],
seq_group_metadata_list: list[SequenceGroupMetadata],
second_last_token_hidden_states: Optional[torch.Tensor] = None):
"""Update hidden states from target model invocation. Only used for
decode steps"""
@@ -1244,7 +1244,7 @@ class HiddenStates(msgspec.Struct, array_like=True,
])
def prune(self,
seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
seq_group_metadata_list: list[SequenceGroupMetadata]) -> None:
"""Prune to provided list of sequence ids. Only used for decode steps.
"""
# Currently this prunes all seq_ids not present in
@@ -1287,16 +1287,16 @@ class ExecuteModelRequest(
"""The model execution request, containing CPU metadata only. The LLM
engine should create an instance of this class for each request batch."""
# The sequence group metadata list.
seq_group_metadata_list: List[Union[SequenceGroupMetadata,
seq_group_metadata_list: list[Union[SequenceGroupMetadata,
SequenceGroupMetadataDelta]]
# Blocks to swap in. List of CPU -> GPU block number.
blocks_to_swap_in: List[Tuple[int,
blocks_to_swap_in: list[tuple[int,
int]] = msgspec.field(default_factory=list)
# Blocks to swap out. List of GPU -> CPU block number.
blocks_to_swap_out: List[Tuple[int,
blocks_to_swap_out: list[tuple[int,
int]] = msgspec.field(default_factory=list)
# Blocks to copy. Source to dest block.
blocks_to_copy: List[Tuple[int, int]] = msgspec.field(default_factory=list)
blocks_to_copy: list[tuple[int, int]] = msgspec.field(default_factory=list)
# Virtual engine ID for pipeline parallel.
virtual_engine: int = 0
# The number of slots for lookahead decoding.
@@ -1310,7 +1310,7 @@ class ExecuteModelRequest(
# The step index for spec model input.
spec_step_idx: Optional[int] = None
# Finished request ids since last step.
finished_requests_ids: List[str] = msgspec.field(default_factory=list)
finished_requests_ids: list[str] = msgspec.field(default_factory=list)
# The last sampled token ids for multi step decoding.
last_sampled_token_ids: Optional[torch.Tensor] = None
# Async callback
@@ -1344,7 +1344,7 @@ class ExecuteModelRequest(
return state.current_step
def clone(
self, seq_group_metadata_list: List[Union[SequenceGroupMetadata,
self, seq_group_metadata_list: list[Union[SequenceGroupMetadata,
SequenceGroupMetadataDelta]]
) -> "ExecuteModelRequest":
"""Clone the request with a new sequence group metadata list."""
@@ -1371,13 +1371,13 @@ class SequenceGroupBase:
assembled_seq_group: Optional[SequenceGroup] = None
# seq id to a unique index inside this group
seq_id_to_index: Dict[str, int] = field(default_factory=dict)
seq_id_to_index: dict[str, int] = field(default_factory=dict)
# seq ids to be finished
to_be_finished: Dict[str, SequenceGroup] = field(default_factory=dict)
to_be_finished: dict[str, SequenceGroup] = field(default_factory=dict)
# seq id to finished sequences
finished_reqs: Dict[str, SequenceGroup] = field(default_factory=dict)
finished_reqs: dict[str, SequenceGroup] = field(default_factory=dict)
streaming: bool = False