Update deprecated Python 3.8 typing (#13971)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -22,22 +22,22 @@ MAX_NUM_PROMPT_TOKENS = 64
|
||||
|
||||
def _remove_requests(
|
||||
input_batch: InputBatch, batch_size: int,
|
||||
reqs: List[CachedRequestState]) -> Tuple[Set[str], List[int]]:
|
||||
reqs: list[CachedRequestState]) -> tuple[set[str], list[int]]:
|
||||
"""
|
||||
Remove some requests randomly from the batch and returns a Tuple
|
||||
Remove some requests randomly from the batch and returns a tuple
|
||||
of 1) set of request removed 2) indices of the requests removed
|
||||
ordered in descending order
|
||||
"""
|
||||
|
||||
num_reqs_to_remove = np.random.randint(0, batch_size)
|
||||
req_indices_to_remove: Set[int] = set()
|
||||
req_indices_to_remove: set[int] = set()
|
||||
for _ in range(num_reqs_to_remove):
|
||||
req_index_to_remove = np.random.randint(0, batch_size)
|
||||
req_indices_to_remove.add(req_index_to_remove)
|
||||
|
||||
req_indices_to_remove_list = list(req_indices_to_remove)
|
||||
req_indices_to_remove_list.sort(reverse=True)
|
||||
req_ids_to_remove: Set[str] = set()
|
||||
req_ids_to_remove: set[str] = set()
|
||||
for index in req_indices_to_remove:
|
||||
input_batch.remove_request(reqs[index].req_id)
|
||||
req_ids_to_remove.add(reqs[index].req_id)
|
||||
@@ -45,9 +45,9 @@ def _remove_requests(
|
||||
|
||||
|
||||
def _construct_expected_sampling_metadata(
|
||||
reqs: List[CachedRequestState],
|
||||
req_ids_retained: Set[int],
|
||||
req_id_index_in_input_batch: Dict[str, int],
|
||||
reqs: list[CachedRequestState],
|
||||
req_ids_retained: set[int],
|
||||
req_id_index_in_input_batch: dict[str, int],
|
||||
device: torch.device,
|
||||
) -> SamplingMetadata:
|
||||
"""
|
||||
@@ -55,8 +55,8 @@ def _construct_expected_sampling_metadata(
|
||||
batch.
|
||||
"""
|
||||
num_reqs = len(req_ids_retained)
|
||||
output_token_ids: List[List[int]] = [list() for _ in range(num_reqs)]
|
||||
prompt_token_ids: List[List[int]] = [list() for _ in range(num_reqs)]
|
||||
output_token_ids: list[list[int]] = [list() for _ in range(num_reqs)]
|
||||
prompt_token_ids: list[list[int]] = [list() for _ in range(num_reqs)]
|
||||
presence_penalties = [0.0 for _ in range(num_reqs)]
|
||||
frequency_penalties = [0.0 for _ in range(num_reqs)]
|
||||
repetition_penalties = [1.0 for _ in range(num_reqs)]
|
||||
@@ -191,7 +191,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
|
||||
pin_memory=is_pin_memory_available(),
|
||||
vocab_size=1024,
|
||||
)
|
||||
reqs: List[CachedRequestState] = []
|
||||
reqs: list[CachedRequestState] = []
|
||||
req_id_reqs = {}
|
||||
req_id_output_token_ids = {}
|
||||
# Add requests
|
||||
|
||||
Reference in New Issue
Block a user