Update deprecated Python 3.8 typing (#13971)

This commit is contained in:
Harry Mellor
2025-03-03 01:34:51 +00:00
committed by GitHub
parent bf33700ecd
commit cf069aa8aa
300 changed files with 2294 additions and 2347 deletions

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Optional, Set, Tuple
from typing import Optional
import numpy as np
import pytest
@@ -22,22 +22,22 @@ MAX_NUM_PROMPT_TOKENS = 64
def _remove_requests(
input_batch: InputBatch, batch_size: int,
reqs: List[CachedRequestState]) -> Tuple[Set[str], List[int]]:
reqs: list[CachedRequestState]) -> tuple[set[str], list[int]]:
"""
Remove some requests randomly from the batch and returns a Tuple
Remove some requests randomly from the batch and returns a tuple
of 1) set of request removed 2) indices of the requests removed
ordered in descending order
"""
num_reqs_to_remove = np.random.randint(0, batch_size)
req_indices_to_remove: Set[int] = set()
req_indices_to_remove: set[int] = set()
for _ in range(num_reqs_to_remove):
req_index_to_remove = np.random.randint(0, batch_size)
req_indices_to_remove.add(req_index_to_remove)
req_indices_to_remove_list = list(req_indices_to_remove)
req_indices_to_remove_list.sort(reverse=True)
req_ids_to_remove: Set[str] = set()
req_ids_to_remove: set[str] = set()
for index in req_indices_to_remove:
input_batch.remove_request(reqs[index].req_id)
req_ids_to_remove.add(reqs[index].req_id)
@@ -45,9 +45,9 @@ def _remove_requests(
def _construct_expected_sampling_metadata(
reqs: List[CachedRequestState],
req_ids_retained: Set[int],
req_id_index_in_input_batch: Dict[str, int],
reqs: list[CachedRequestState],
req_ids_retained: set[int],
req_id_index_in_input_batch: dict[str, int],
device: torch.device,
) -> SamplingMetadata:
"""
@@ -55,8 +55,8 @@ def _construct_expected_sampling_metadata(
batch.
"""
num_reqs = len(req_ids_retained)
output_token_ids: List[List[int]] = [list() for _ in range(num_reqs)]
prompt_token_ids: List[List[int]] = [list() for _ in range(num_reqs)]
output_token_ids: list[list[int]] = [list() for _ in range(num_reqs)]
prompt_token_ids: list[list[int]] = [list() for _ in range(num_reqs)]
presence_penalties = [0.0 for _ in range(num_reqs)]
frequency_penalties = [0.0 for _ in range(num_reqs)]
repetition_penalties = [1.0 for _ in range(num_reqs)]
@@ -191,7 +191,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
pin_memory=is_pin_memory_available(),
vocab_size=1024,
)
reqs: List[CachedRequestState] = []
reqs: list[CachedRequestState] = []
req_id_reqs = {}
req_id_output_token_ids = {}
# Add requests