Update deprecated Python 3.8 typing (#13971)

This commit is contained in:
Harry Mellor
2025-03-03 01:34:51 +00:00
committed by GitHub
parent bf33700ecd
commit cf069aa8aa
300 changed files with 2294 additions and 2347 deletions

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, Generator, List, Optional
from collections.abc import Generator
from typing import Any, Optional
import pytest
from transformers import AutoTokenizer
@@ -163,7 +164,7 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
@pytest.fixture(name="complete_sequence_token_ids")
def create_complete_sequence_token_ids(complete_sequence: str,
tokenizer) -> List[int]:
tokenizer) -> list[int]:
complete_sequence_token_ids = tokenizer(complete_sequence).input_ids
return complete_sequence_token_ids
@@ -178,7 +179,7 @@ def create_sequence(prompt_token_ids=None):
def create_dummy_logprobs(
complete_sequence_token_ids: List[int]) -> List[Dict[int, Logprob]]:
complete_sequence_token_ids: list[int]) -> list[dict[int, Logprob]]:
return [{
token_id: Logprob(logprob=0.0),
token_id + 1: Logprob(logprob=0.1)
@@ -186,10 +187,10 @@ def create_dummy_logprobs(
def create_dummy_prompt_logprobs(
complete_sequence_token_ids: List[int]
) -> List[Optional[Dict[int, Any]]]:
complete_sequence_token_ids: list[int]
) -> list[Optional[dict[int, Any]]]:
# logprob for the first prompt token is None.
logprobs: List[Optional[Dict[int, Any]]] = [None]
logprobs: list[Optional[dict[int, Any]]] = [None]
logprobs.extend(create_dummy_logprobs(complete_sequence_token_ids)[1:])
return logprobs
@@ -198,7 +199,7 @@ def create_dummy_prompt_logprobs(
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", [True, False], indirect=True)
def test_decode_sequence_logprobs(complete_sequence: str,
complete_sequence_token_ids: List[int],
complete_sequence_token_ids: list[int],
detokenizer: Detokenizer,
skip_special_tokens: bool):
"""Verify Detokenizer decodes logprobs correctly."""
@@ -208,8 +209,8 @@ def test_decode_sequence_logprobs(complete_sequence: str,
# Run sequentially.
seq = create_sequence()
dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
sequential_logprobs_text_chosen_token: List[str] = []
sequential_logprobs_text_other_token: List[str] = []
sequential_logprobs_text_chosen_token: list[str] = []
sequential_logprobs_text_other_token: list[str] = []
for new_token, logprobs in zip(complete_sequence_token_ids,
dummy_logprobs):
seq.append_token_id(new_token, logprobs)
@@ -232,7 +233,7 @@ def test_decode_sequence_logprobs(complete_sequence: str,
@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int],
def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
detokenizer: Detokenizer):
"""Verify Detokenizer decodes prompt logprobs correctly."""
sampling_params = SamplingParams(skip_special_tokens=True,
@@ -249,7 +250,7 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int],
dummy_logprobs,
position_offset=0)
# First logprob is None.
decoded_prompt_logprobs: List[Dict[int, Any]] = dummy_logprobs[
decoded_prompt_logprobs: list[dict[int, Any]] = dummy_logprobs[
1:] # type: ignore
# decoded_prompt_logprobs doesn't contain the first token.