Update deprecated Python 3.8 typing (#13971)

This commit is contained in:
Harry Mellor
2025-03-03 01:34:51 +00:00
committed by GitHub
parent bf33700ecd
commit cf069aa8aa
300 changed files with 2294 additions and 2347 deletions

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Optional, Set, Tuple
from typing import Optional
import numpy as np
import pytest
@@ -32,7 +32,7 @@ def _create_penalty_tensor(batch_size: int, penalty_value: float,
def _create_prompt_tokens_tensor(
prompt_token_ids: List[List[int]],
prompt_token_ids: list[list[int]],
vocab_size: int,
device: torch.device,
) -> torch.Tensor:
@@ -49,8 +49,8 @@ def _create_logit_bias(
batch_size: int,
vocab_size: int,
bias_value: float,
) -> List[Optional[Dict[int, float]]]:
res: List[Optional[Dict[int, float]]] = []
) -> list[Optional[dict[int, float]]]:
res: list[Optional[dict[int, float]]] = []
for i in range(batch_size):
logit_bias = {min(i, vocab_size - 1): bias_value}
res.append(logit_bias)
@@ -83,8 +83,8 @@ def _create_default_sampling_metadata(
vocab_size: int,
device: torch.device,
) -> SamplingMetadata:
output_token_ids: List[List[int]] = []
prompt_token_ids: List[List[int]] = []
output_token_ids: list[list[int]] = []
prompt_token_ids: list[list[int]] = []
for _ in range(batch_size):
output_token_ids.append(
np.random.randint(0, vocab_size, size=num_output_tokens).tolist())
@@ -118,8 +118,8 @@ def _create_default_sampling_metadata(
def _generate_min_token_penalties_and_stop_tokens(
num_output_tokens: int, batch_size: int, vocab_size: int,
batch_indices_for_min_token_penalty: List[int]
) -> Dict[int, Tuple[int, Set[int]]]:
batch_indices_for_min_token_penalty: list[int]
) -> dict[int, tuple[int, set[int]]]:
"""
Generates and returns a dict of minimum token penalties and
corresponding stop token IDs (`min_tokens`, `stop_token_ids`) for each
@@ -130,7 +130,7 @@ def _generate_min_token_penalties_and_stop_tokens(
and a random set of stop token IDs is created. Otherwise, a lower
`min_tokens` value is assigned, and the stop token IDs set is empty.
"""
min_tokens: Dict[int, Tuple[int, Set[int]]] = {}
min_tokens: dict[int, tuple[int, set[int]]] = {}
for index in range(batch_size):
if index in batch_indices_for_min_token_penalty:
min_tokens[index] = (
@@ -147,7 +147,7 @@ def _generate_min_token_penalties_and_stop_tokens(
def _create_weighted_output_token_list(
batch_size: int,
vocab_size: int) -> Tuple[List[List[int]], List[List[int]]]:
vocab_size: int) -> tuple[list[list[int]], list[list[int]]]:
"""
Creates an output token list where each token occurs a distinct
number of times.
@@ -157,7 +157,7 @@ def _create_weighted_output_token_list(
list, each with a different frequency.
Returns:
Tuple[List[List[int]], List[List[int]]]:
tuple[list[list[int]], list[list[int]]]:
- The first element is the output token list, where each sublist
corresponds to a batch and contains tokens with weighted
frequencies.
@@ -165,8 +165,8 @@ def _create_weighted_output_token_list(
batch, ordered by their frequency in the corresponding output
list.
"""
output_token_ids: List[List[int]] = []
sorted_token_ids_in_output: List[List[int]] = []
output_token_ids: list[list[int]] = []
sorted_token_ids_in_output: list[list[int]] = []
for _ in range(batch_size):
distinct_token_ids = np.random.choice(vocab_size,
size=np.random.randint(1, 10),