Update deprecated Python 3.8 typing (#13971)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -32,7 +32,7 @@ def _create_penalty_tensor(batch_size: int, penalty_value: float,
|
||||
|
||||
|
||||
def _create_prompt_tokens_tensor(
|
||||
prompt_token_ids: List[List[int]],
|
||||
prompt_token_ids: list[list[int]],
|
||||
vocab_size: int,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
@@ -49,8 +49,8 @@ def _create_logit_bias(
|
||||
batch_size: int,
|
||||
vocab_size: int,
|
||||
bias_value: float,
|
||||
) -> List[Optional[Dict[int, float]]]:
|
||||
res: List[Optional[Dict[int, float]]] = []
|
||||
) -> list[Optional[dict[int, float]]]:
|
||||
res: list[Optional[dict[int, float]]] = []
|
||||
for i in range(batch_size):
|
||||
logit_bias = {min(i, vocab_size - 1): bias_value}
|
||||
res.append(logit_bias)
|
||||
@@ -83,8 +83,8 @@ def _create_default_sampling_metadata(
|
||||
vocab_size: int,
|
||||
device: torch.device,
|
||||
) -> SamplingMetadata:
|
||||
output_token_ids: List[List[int]] = []
|
||||
prompt_token_ids: List[List[int]] = []
|
||||
output_token_ids: list[list[int]] = []
|
||||
prompt_token_ids: list[list[int]] = []
|
||||
for _ in range(batch_size):
|
||||
output_token_ids.append(
|
||||
np.random.randint(0, vocab_size, size=num_output_tokens).tolist())
|
||||
@@ -118,8 +118,8 @@ def _create_default_sampling_metadata(
|
||||
|
||||
def _generate_min_token_penalties_and_stop_tokens(
|
||||
num_output_tokens: int, batch_size: int, vocab_size: int,
|
||||
batch_indices_for_min_token_penalty: List[int]
|
||||
) -> Dict[int, Tuple[int, Set[int]]]:
|
||||
batch_indices_for_min_token_penalty: list[int]
|
||||
) -> dict[int, tuple[int, set[int]]]:
|
||||
"""
|
||||
Generates and returns a dict of minimum token penalties and
|
||||
corresponding stop token IDs (`min_tokens`, `stop_token_ids`) for each
|
||||
@@ -130,7 +130,7 @@ def _generate_min_token_penalties_and_stop_tokens(
|
||||
and a random set of stop token IDs is created. Otherwise, a lower
|
||||
`min_tokens` value is assigned, and the stop token IDs set is empty.
|
||||
"""
|
||||
min_tokens: Dict[int, Tuple[int, Set[int]]] = {}
|
||||
min_tokens: dict[int, tuple[int, set[int]]] = {}
|
||||
for index in range(batch_size):
|
||||
if index in batch_indices_for_min_token_penalty:
|
||||
min_tokens[index] = (
|
||||
@@ -147,7 +147,7 @@ def _generate_min_token_penalties_and_stop_tokens(
|
||||
|
||||
def _create_weighted_output_token_list(
|
||||
batch_size: int,
|
||||
vocab_size: int) -> Tuple[List[List[int]], List[List[int]]]:
|
||||
vocab_size: int) -> tuple[list[list[int]], list[list[int]]]:
|
||||
"""
|
||||
Creates an output token list where each token occurs a distinct
|
||||
number of times.
|
||||
@@ -157,7 +157,7 @@ def _create_weighted_output_token_list(
|
||||
list, each with a different frequency.
|
||||
|
||||
Returns:
|
||||
Tuple[List[List[int]], List[List[int]]]:
|
||||
tuple[list[list[int]], list[list[int]]]:
|
||||
- The first element is the output token list, where each sublist
|
||||
corresponds to a batch and contains tokens with weighted
|
||||
frequencies.
|
||||
@@ -165,8 +165,8 @@ def _create_weighted_output_token_list(
|
||||
batch, ordered by their frequency in the corresponding output
|
||||
list.
|
||||
"""
|
||||
output_token_ids: List[List[int]] = []
|
||||
sorted_token_ids_in_output: List[List[int]] = []
|
||||
output_token_ids: list[list[int]] = []
|
||||
sorted_token_ids_in_output: list[list[int]] = []
|
||||
for _ in range(batch_size):
|
||||
distinct_token_ids = np.random.choice(vocab_size,
|
||||
size=np.random.randint(1, 10),
|
||||
|
||||
Reference in New Issue
Block a user