Update deprecated Python 3.8 typing (#13971)

This commit is contained in:
Harry Mellor
2025-03-03 01:34:51 +00:00
committed by GitHub
parent bf33700ecd
commit cf069aa8aa
300 changed files with 2294 additions and 2347 deletions

View File

@@ -1,13 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, List, Tuple, Union
from typing import Callable, Union
import torch
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor],
Callable[[List[int], List[int], torch.Tensor],
LogitsProcessor = Union[Callable[[list[int], torch.Tensor], torch.Tensor],
Callable[[list[int], list[int], torch.Tensor],
torch.Tensor]]
"""LogitsProcessor is a function that takes a list
of previously generated tokens, the logits tensor
@@ -17,9 +17,9 @@ to sample from."""
def get_bad_words_logits_processors(
bad_words: List[str],
tokenizer: AnyTokenizer) -> List[LogitsProcessor]:
bad_words_ids: List[List[int]] = list()
bad_words: list[str],
tokenizer: AnyTokenizer) -> list[LogitsProcessor]:
bad_words_ids: list[list[int]] = list()
for bad_word in bad_words:
# To prohibit words both at the beginning
@@ -51,13 +51,13 @@ class NoBadWordsLogitsProcessor:
_SMALLEST_LOGIT = float("-inf")
_NEUTRAL_LOGIT = 0.0
def __init__(self, bad_words_ids: List[List[int]]):
def __init__(self, bad_words_ids: list[list[int]]):
self.bad_words_ids = bad_words_ids
self.word_bias: torch.FloatTensor = None
def __call__(
self,
past_tokens_ids: Union[List[int], Tuple[int]],
past_tokens_ids: Union[list[int], tuple[int]],
logits: torch.FloatTensor,
) -> torch.Tensor:
if self.word_bias is None: