Update deprecated Python 3.8 typing (#13971)
This commit is contained in:
@@ -1,13 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Callable, List, Tuple, Union
|
||||
from typing import Callable, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
|
||||
LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor],
|
||||
Callable[[List[int], List[int], torch.Tensor],
|
||||
LogitsProcessor = Union[Callable[[list[int], torch.Tensor], torch.Tensor],
|
||||
Callable[[list[int], list[int], torch.Tensor],
|
||||
torch.Tensor]]
|
||||
"""LogitsProcessor is a function that takes a list
|
||||
of previously generated tokens, the logits tensor
|
||||
@@ -17,9 +17,9 @@ to sample from."""
|
||||
|
||||
|
||||
def get_bad_words_logits_processors(
|
||||
bad_words: List[str],
|
||||
tokenizer: AnyTokenizer) -> List[LogitsProcessor]:
|
||||
bad_words_ids: List[List[int]] = list()
|
||||
bad_words: list[str],
|
||||
tokenizer: AnyTokenizer) -> list[LogitsProcessor]:
|
||||
bad_words_ids: list[list[int]] = list()
|
||||
|
||||
for bad_word in bad_words:
|
||||
# To prohibit words both at the beginning
|
||||
@@ -51,13 +51,13 @@ class NoBadWordsLogitsProcessor:
|
||||
_SMALLEST_LOGIT = float("-inf")
|
||||
_NEUTRAL_LOGIT = 0.0
|
||||
|
||||
def __init__(self, bad_words_ids: List[List[int]]):
|
||||
def __init__(self, bad_words_ids: list[list[int]]):
|
||||
self.bad_words_ids = bad_words_ids
|
||||
self.word_bias: torch.FloatTensor = None
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
past_tokens_ids: Union[List[int], Tuple[int]],
|
||||
past_tokens_ids: Union[list[int], tuple[int]],
|
||||
logits: torch.FloatTensor,
|
||||
) -> torch.Tensor:
|
||||
if self.word_bias is None:
|
||||
|
||||
Reference in New Issue
Block a user