Update deprecated Python 3.8 typing (#13971)

Author:    Harry Mellor
Date:      2025-03-03 01:34:51 +00:00
Committer: GitHub
Parent:    bf33700ecd
Commit:    cf069aa8aa

300 changed files with 2294 additions and 2347 deletions
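The mechanical change throughout is the PEP 585 migration: from Python 3.9 the builtin container types are subscriptable in annotations, so the deprecated typing.List, typing.Dict, and friends can be dropped in favor of list and dict. A minimal before/after sketch of the pattern, using hypothetical names not taken from this diff:

# Python 3.8 style: deprecated typing aliases.
from typing import Dict, List

def token_lengths_old(prompts: List[List[int]]) -> Dict[int, int]:
    return {i: len(p) for i, p in enumerate(prompts)}

# Python 3.9+ style (PEP 585): subscript the builtins directly.
def token_lengths_new(prompts: list[list[int]]) -> dict[int, int]:
    return {i: len(p) for i, p in enumerate(prompts)}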


@@ -1,9 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from collections.abc import Sequence as GenericSequence
 from itertools import count
-from typing import Callable, Dict, List, Optional
-from typing import Sequence as GenericSequence
-from typing import TypeVar, Union
+from typing import Callable, Optional, TypeVar, Union
 from unittest.mock import MagicMock
 
 import torch
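Note that Sequence moves to collections.abc rather than disappearing: typing.Sequence has been a deprecated alias of collections.abc.Sequence since Python 3.9, and the ABC itself became subscriptable at the same time, so re-aliasing it as GenericSequence leaves every annotation site untouched. A small sketch of the new import in use (hypothetical function, not from the diff):

from collections.abc import Sequence

# The ABC covers any sequence type and is subscriptable on Python 3.9+.
def first_token(tokens: Sequence[int]) -> int:
    return tokens[0]

print(first_token([7, 8, 9]))  # prints 7; a tuple would work too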
@@ -44,7 +43,7 @@ def mock_worker(cls=None,
     return worker
 
 
-def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]):
+def patch_execute_model_with_seeds(worker: Worker, rand_seeds: list[int]):
    seed_iter = iter(rand_seeds)
    original_execute_model = worker.execute_model
@@ -56,7 +55,7 @@ def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]):
     return new_execute_model
 
 
-def zero_kv_cache(cache_engine: List[CacheEngine]):
+def zero_kv_cache(cache_engine: list[CacheEngine]):
     assert cache_engine[0].gpu_cache
     for key_blocks, value_blocks in cache_engine[0].gpu_cache:
         key_blocks.zero_()
@@ -106,13 +105,13 @@ def create_worker(cls: Callable[..., T],
 
 
 def create_seq_group_metadata_from_prompts(
-    prompts: List[List[int]],
+    prompts: list[list[int]],
     num_gpu_blocks: int,
     block_size: int,
-    final_prompt_lens: List[int],
-    continuations: Optional[List[List[int]]] = None,
-    seq_ids: Optional[List[int]] = None,
-) -> List[SequenceGroupMetadata]:
+    final_prompt_lens: list[int],
+    continuations: Optional[list[list[int]]] = None,
+    seq_ids: Optional[list[int]] = None,
+) -> list[SequenceGroupMetadata]:
     if continuations is None:
         continuations = [[] for _ in prompts]
@@ -149,11 +148,11 @@ def create_seq_group_metadata_from_prompts(
 
 
 def create_chunked_seq_group_metadata_from_prompt(
-        prompt: List[int],
+        prompt: list[int],
         num_gpu_blocks: int,
         chunk_size: int,
         block_size: int,
-        seq_id: Optional[int] = None) -> List[SequenceGroupMetadata]:
+        seq_id: Optional[int] = None) -> list[SequenceGroupMetadata]:
     if seq_id is None:
         seq_id = 0
@@ -184,8 +183,8 @@ def create_chunked_seq_group_metadata_from_prompt(
 def assert_logprobs_dict_allclose(
-        actual_logprobs: List[Dict[int, Logprob]],
-        expected_logprobs: List[Dict[int, Logprob]]) -> None:
+        actual_logprobs: list[dict[int, Logprob]],
+        expected_logprobs: list[dict[int, Logprob]]) -> None:
     for single_step_actual_logprobs, single_step_expected_logprobs in zip(
             actual_logprobs, expected_logprobs):
         assert set(single_step_actual_logprobs.keys()) == set(
@@ -202,7 +201,7 @@ def create_sampler_output_list(
         token_ids: torch.Tensor,
         probs: GenericSequence[Optional[torch.Tensor]],
         logprobs: GenericSequence[Optional[torch.Tensor]],
-        seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]:
+        seq_ids: Optional[list[int]] = None) -> list[SamplerOutput]:
     num_steps, batch_size = token_ids.shape
     token_ids_by_step = token_ids.tolist()
@@ -231,9 +230,9 @@ def create_sampler_output_list(
 def create_batch(batch_size,
                  k,
-                 prompt_len: Union[int, List[int]] = 10,
+                 prompt_len: Union[int, list[int]] = 10,
                  prev_output_token_len: int = 10,
-                 seq_ids: Optional[List[int]] = None,
+                 seq_ids: Optional[list[int]] = None,
                  num_gpu_blocks: Optional[int] = None,
                  block_size: Optional[int] = None,
                  prefill_chunk_size: Optional[int] = None):
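Callable, Optional, TypeVar, and Union stay in the typing imports throughout: Optional and Union can only be replaced by PEP 604's X | Y syntax on Python 3.10+, and TypeVar has no builtin equivalent before PEP 695 in Python 3.12, so a codebase whose floor appears to be Python 3.9 (as the commit title suggests) keeps them. collections.abc.Callable is subscriptable from 3.9 too, but this commit leaves typing.Callable alone. A sketch of how these signatures could shrink further on a 3.10+ floor, illustrative only and not part of this commit:

from typing import Optional, Union

# Python 3.9 floor: typing.Optional/Union are still required.
def pick_len(prompt_len: Union[int, list[int]],
             seq_ids: Optional[list[int]] = None) -> int:
    return prompt_len if isinstance(prompt_len, int) else prompt_len[0]

# Python 3.10+ could use PEP 604 unions instead:
# def pick_len(prompt_len: int | list[int],
#              seq_ids: list[int] | None = None) -> int: ...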