Update deprecated Python 3.8 typing (#13971)
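Context for the change (not part of the diff): PEP 585, adopted in Python 3.9, made the built-in container types directly usable as generic annotations, so the typing aliases List, Dict, and Sequence are deprecated. A minimal before/after sketch; the lookup function is hypothetical, purely for illustration:

    # Python 3.8 style (deprecated since 3.9):
    #     from typing import Dict, List, Optional
    #     def lookup(table: Dict[str, List[int]], key: str) -> Optional[List[int]]: ...
    #
    # Python 3.9+ style (PEP 585): the built-in types are generic themselves.
    from typing import Optional

    def lookup(table: dict[str, list[int]], key: str) -> Optional[list[int]]:
        return table.get(key)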
@@ -1,9 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0

+from collections.abc import Sequence as GenericSequence
 from itertools import count
-from typing import Callable, Dict, List, Optional
-from typing import Sequence as GenericSequence
-from typing import TypeVar, Union
+from typing import Callable, Optional, TypeVar, Union
 from unittest.mock import MagicMock

 import torch
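Aside (not part of the commit): Sequence moves from typing to collections.abc because, on Python 3.9+, the abstract base classes in collections.abc are themselves subscriptable at runtime under PEP 585, so no typing import is needed. A small sketch; total is an illustrative function, not from this file:

    from collections.abc import Sequence

    def total(xs: Sequence[int]) -> int:
        # collections.abc.Sequence parameterized directly, no typing import needed
        return sum(xs)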
@@ -44,7 +43,7 @@ def mock_worker(cls=None,
     return worker


-def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]):
+def patch_execute_model_with_seeds(worker: Worker, rand_seeds: list[int]):
     seed_iter = iter(rand_seeds)
     original_execute_model = worker.execute_model

@@ -56,7 +55,7 @@ def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]):
     return new_execute_model


-def zero_kv_cache(cache_engine: List[CacheEngine]):
+def zero_kv_cache(cache_engine: list[CacheEngine]):
     assert cache_engine[0].gpu_cache
     for key_blocks, value_blocks in cache_engine[0].gpu_cache:
         key_blocks.zero_()
@@ -106,13 +105,13 @@ def create_worker(cls: Callable[..., T],


 def create_seq_group_metadata_from_prompts(
-        prompts: List[List[int]],
+        prompts: list[list[int]],
         num_gpu_blocks: int,
         block_size: int,
-        final_prompt_lens: List[int],
-        continuations: Optional[List[List[int]]] = None,
-        seq_ids: Optional[List[int]] = None,
-) -> List[SequenceGroupMetadata]:
+        final_prompt_lens: list[int],
+        continuations: Optional[list[list[int]]] = None,
+        seq_ids: Optional[list[int]] = None,
+) -> list[SequenceGroupMetadata]:

     if continuations is None:
         continuations = [[] for _ in prompts]
@@ -149,11 +148,11 @@ def create_seq_group_metadata_from_prompts(


 def create_chunked_seq_group_metadata_from_prompt(
-        prompt: List[int],
+        prompt: list[int],
         num_gpu_blocks: int,
         chunk_size: int,
         block_size: int,
-        seq_id: Optional[int] = None) -> List[SequenceGroupMetadata]:
+        seq_id: Optional[int] = None) -> list[SequenceGroupMetadata]:

     if seq_id is None:
         seq_id = 0
@@ -184,8 +183,8 @@ def create_chunked_seq_group_metadata_from_prompt(


 def assert_logprobs_dict_allclose(
-        actual_logprobs: List[Dict[int, Logprob]],
-        expected_logprobs: List[Dict[int, Logprob]]) -> None:
+        actual_logprobs: list[dict[int, Logprob]],
+        expected_logprobs: list[dict[int, Logprob]]) -> None:
     for single_step_actual_logprobs, single_step_expected_logprobs in zip(
             actual_logprobs, expected_logprobs):
         assert set(single_step_actual_logprobs.keys()) == set(
@@ -202,7 +201,7 @@ def create_sampler_output_list(
         token_ids: torch.Tensor,
         probs: GenericSequence[Optional[torch.Tensor]],
         logprobs: GenericSequence[Optional[torch.Tensor]],
-        seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]:
+        seq_ids: Optional[list[int]] = None) -> list[SamplerOutput]:
     num_steps, batch_size = token_ids.shape
     token_ids_by_step = token_ids.tolist()

@@ -231,9 +230,9 @@ def create_sampler_output_list(

 def create_batch(batch_size,
                  k,
-                 prompt_len: Union[int, List[int]] = 10,
+                 prompt_len: Union[int, list[int]] = 10,
                  prev_output_token_len: int = 10,
-                 seq_ids: Optional[List[int]] = None,
+                 seq_ids: Optional[list[int]] = None,
                  num_gpu_blocks: Optional[int] = None,
                  block_size: Optional[int] = None,
                  prefill_chunk_size: Optional[int] = None):
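Note (not from the commit): rewrites like these are mechanical and are commonly automated, e.g. with pyupgrade (--py39-plus) or Ruff's pyupgrade-derived UP lint rules, rather than edited by hand.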