Update deprecated Python 3.8 typing (#13971)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Dict, Generator, List, Optional
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
@@ -163,7 +164,7 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
|
||||
|
||||
@pytest.fixture(name="complete_sequence_token_ids")
|
||||
def create_complete_sequence_token_ids(complete_sequence: str,
|
||||
tokenizer) -> List[int]:
|
||||
tokenizer) -> list[int]:
|
||||
complete_sequence_token_ids = tokenizer(complete_sequence).input_ids
|
||||
return complete_sequence_token_ids
|
||||
|
||||
@@ -178,7 +179,7 @@ def create_sequence(prompt_token_ids=None):
|
||||
|
||||
|
||||
def create_dummy_logprobs(
|
||||
complete_sequence_token_ids: List[int]) -> List[Dict[int, Logprob]]:
|
||||
complete_sequence_token_ids: list[int]) -> list[dict[int, Logprob]]:
|
||||
return [{
|
||||
token_id: Logprob(logprob=0.0),
|
||||
token_id + 1: Logprob(logprob=0.1)
|
||||
@@ -186,10 +187,10 @@ def create_dummy_logprobs(
|
||||
|
||||
|
||||
def create_dummy_prompt_logprobs(
|
||||
complete_sequence_token_ids: List[int]
|
||||
) -> List[Optional[Dict[int, Any]]]:
|
||||
complete_sequence_token_ids: list[int]
|
||||
) -> list[Optional[dict[int, Any]]]:
|
||||
# logprob for the first prompt token is None.
|
||||
logprobs: List[Optional[Dict[int, Any]]] = [None]
|
||||
logprobs: list[Optional[dict[int, Any]]] = [None]
|
||||
logprobs.extend(create_dummy_logprobs(complete_sequence_token_ids)[1:])
|
||||
return logprobs
|
||||
|
||||
@@ -198,7 +199,7 @@ def create_dummy_prompt_logprobs(
|
||||
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
|
||||
@pytest.mark.parametrize("skip_special_tokens", [True, False], indirect=True)
|
||||
def test_decode_sequence_logprobs(complete_sequence: str,
|
||||
complete_sequence_token_ids: List[int],
|
||||
complete_sequence_token_ids: list[int],
|
||||
detokenizer: Detokenizer,
|
||||
skip_special_tokens: bool):
|
||||
"""Verify Detokenizer decodes logprobs correctly."""
|
||||
@@ -208,8 +209,8 @@ def test_decode_sequence_logprobs(complete_sequence: str,
|
||||
# Run sequentially.
|
||||
seq = create_sequence()
|
||||
dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
|
||||
sequential_logprobs_text_chosen_token: List[str] = []
|
||||
sequential_logprobs_text_other_token: List[str] = []
|
||||
sequential_logprobs_text_chosen_token: list[str] = []
|
||||
sequential_logprobs_text_other_token: list[str] = []
|
||||
for new_token, logprobs in zip(complete_sequence_token_ids,
|
||||
dummy_logprobs):
|
||||
seq.append_token_id(new_token, logprobs)
|
||||
@@ -232,7 +233,7 @@ def test_decode_sequence_logprobs(complete_sequence: str,
|
||||
|
||||
@pytest.mark.parametrize("complete_sequence", TRUTH)
|
||||
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
|
||||
def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int],
|
||||
def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
|
||||
detokenizer: Detokenizer):
|
||||
"""Verify Detokenizer decodes prompt logprobs correctly."""
|
||||
sampling_params = SamplingParams(skip_special_tokens=True,
|
||||
@@ -249,7 +250,7 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int],
|
||||
dummy_logprobs,
|
||||
position_offset=0)
|
||||
# First logprob is None.
|
||||
decoded_prompt_logprobs: List[Dict[int, Any]] = dummy_logprobs[
|
||||
decoded_prompt_logprobs: list[dict[int, Any]] = dummy_logprobs[
|
||||
1:] # type: ignore
|
||||
|
||||
# decoded_prompt_logprobs doesn't contain the first token.
|
||||
|
||||
Reference in New Issue
Block a user