Update deprecated type hinting in vllm/transformers_utils (#18058)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
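
This change replaces the deprecated `typing.List`/`typing.Dict` aliases with the builtin generics `list`/`dict` (usable in annotations since Python 3.9, per PEP 585); `Optional`, `Union`, and `cast` are still imported from `typing`. A minimal sketch of the pattern being applied (the `lookup` function and its names below are illustrative only, not part of this diff):

```python
from typing import Optional

# Before: from typing import Dict, List, Optional
# def lookup(ids: List[int], vocab: Dict[str, int]) -> Optional[List[str]]: ...

# After: builtin generics, no Dict/List imports needed
def lookup(ids: list[int], vocab: dict[str, int]) -> Optional[list[str]]:
    """Illustrative only: map token ids back to strings via a reversed vocab."""
    reverse: dict[int, str] = {v: k for k, v in vocab.items()}
    tokens = [reverse[i] for i in ids if i in reverse]
    return tokens or None
```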
@@ -4,7 +4,7 @@ import os
 import re
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
+from typing import TYPE_CHECKING, Any, Optional, Union, cast
 
 import huggingface_hub
 from huggingface_hub import HfApi, hf_hub_download
@@ -28,7 +28,7 @@ logger = init_logger(__name__)
 
 @dataclass
 class Encoding:
-    input_ids: Union[List[int], List[List[int]]]
+    input_ids: Union[list[int], list[list[int]]]
 
 
 def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
@@ -105,7 +105,7 @@ def validate_request_params(request: "ChatCompletionRequest"):
             "for Mistral tokenizers.")
 
 
-def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
+def list_local_repo_files(repo_id: str, revision: Optional[str]) -> list[str]:
     repo_cache = os.path.join(
         huggingface_hub.constants.HF_HUB_CACHE,
         huggingface_hub.constants.REPO_ID_SEPARATOR.join(
@@ -125,7 +125,7 @@ def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
     return []
 
 
-def find_tokenizer_file(files: List[str]):
+def find_tokenizer_file(files: list[str]):
     file_pattern = re.compile(
         r"^tokenizer\.model\.v.*$|^tekken\.json$|^tokenizer\.mm\.model\.v.*$")
 
@@ -145,10 +145,10 @@ def find_tokenizer_file(files: List[str]):
 
 
 def make_mistral_chat_completion_request(
-        messages: List["ChatCompletionMessageParam"],
-        tools: Optional[List[Dict[str,
+        messages: list["ChatCompletionMessageParam"],
+        tools: Optional[list[dict[str,
                                   Any]]] = None) -> "ChatCompletionRequest":
-    last_message = cast(Dict[str, Any], messages[-1])
+    last_message = cast(dict[str, Any], messages[-1])
     if last_message["role"] == "assistant":
         last_message["prefix"] = True
 
@@ -199,7 +199,7 @@ class MistralTokenizer(TokenizerBase):
             raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")
 
         self._vocab = tokenizer_.vocab()
-        # Convert to a Dict[str, int] to match protocol, but this is a lossy
+        # Convert to a dict[str, int] to match protocol, but this is a lossy
         # conversion. There may be multiple token ids that decode to the same
         # string due to partial UTF-8 byte sequences being converted to �
         self._vocab_dict = {
@@ -314,21 +314,21 @@ class MistralTokenizer(TokenizerBase):
 
     def __call__(
         self,
-        text: Union[str, List[str], List[int]],
+        text: Union[str, list[str], list[int]],
         text_pair: Optional[str] = None,
        add_special_tokens: bool = False,
         truncation: bool = False,
         max_length: Optional[int] = None,
     ):
-        input_ids: Union[List[int], List[List[int]]]
-        # For List[str], original prompt text
+        input_ids: Union[list[int], list[list[int]]]
+        # For list[str], original prompt text
         if is_list_of(text, str):
-            input_ids_: List[List[int]] = []
+            input_ids_: list[list[int]] = []
             for p in text:
                 each_input_ids = self.encode_one(p, truncation, max_length)
                 input_ids_.append(each_input_ids)
             input_ids = input_ids_
-        # For List[int], apply chat template output, already tokens.
+        # For list[int], apply chat template output, already tokens.
         elif is_list_of(text, int):
             input_ids = text
         # For str, single prompt text
@@ -350,7 +350,7 @@ class MistralTokenizer(TokenizerBase):
         text: str,
         truncation: bool = False,
         max_length: Optional[int] = None,
-    ) -> List[int]:
+    ) -> list[int]:
         # Mistral Tokenizers should not add special tokens
         input_ids = self.encode(text)
 
@@ -362,7 +362,7 @@ class MistralTokenizer(TokenizerBase):
                text: str,
                truncation: Optional[bool] = None,
                max_length: Optional[int] = None,
-               add_special_tokens: Optional[bool] = None) -> List[int]:
+               add_special_tokens: Optional[bool] = None) -> list[int]:
         # `encode` should only be used for prompt completion
         # it should never be used for chat_completion.
         # For chat completion use `apply_chat_template`
@@ -374,9 +374,9 @@ class MistralTokenizer(TokenizerBase):
         return self.tokenizer.encode(text, bos=True, eos=False)
 
     def apply_chat_template(self,
-                            messages: List["ChatCompletionMessageParam"],
-                            tools: Optional[List[Dict[str, Any]]] = None,
-                            **kwargs) -> List[int]:
+                            messages: list["ChatCompletionMessageParam"],
+                            tools: Optional[list[dict[str, Any]]] = None,
+                            **kwargs) -> list[int]:
 
         request = make_mistral_chat_completion_request(messages, tools)
         encoded = self.mistral.encode_chat_completion(request)
@@ -384,7 +384,7 @@ class MistralTokenizer(TokenizerBase):
         # encode-decode to get clean prompt
         return encoded.tokens
 
-    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+    def convert_tokens_to_string(self, tokens: list[str]) -> str:
         from mistral_common.tokens.tokenizers.base import SpecialTokens
         if self.is_tekken:
             tokens = [
@@ -417,7 +417,7 @@ class MistralTokenizer(TokenizerBase):
             # make sure certain special tokens like Tool calls are
             # not decoded
             special_tokens = {SpecialTokens.tool_calls}
-            regular_tokens: List[str] = []
+            regular_tokens: list[str] = []
             decoded_list = []
 
             for token in tokens:
@@ -442,7 +442,7 @@ class MistralTokenizer(TokenizerBase):
     # See: guided_decoding/outlines_logits_processors.py::_adapt_tokenizer
     # for more.
     def decode(self,
-               ids: Union[List[int], int],
+               ids: Union[list[int], int],
                skip_special_tokens: bool = True) -> str:
         assert (
             skip_special_tokens
@@ -454,9 +454,9 @@ class MistralTokenizer(TokenizerBase):
 
     def convert_ids_to_tokens(
         self,
-        ids: List[int],
+        ids: list[int],
         skip_special_tokens: bool = True,
-    ) -> List[str]:
+    ) -> list[str]:
         from mistral_common.tokens.tokenizers.base import SpecialTokens
 
         # TODO(Patrick) - potentially allow special tokens to not be skipped