Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer_base import TokenizerBase
|
||||
@@ -90,10 +90,10 @@ def truncate_tool_call_ids(request: "MistralChatCompletionRequest"):
|
||||
|
||||
def _prepare_apply_chat_template_tools_and_messages(
|
||||
messages: list["ChatCompletionMessageParam"],
|
||||
tools: Optional[list[dict[str, Any]]] = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
continue_final_message: bool = False,
|
||||
add_generation_prompt: bool = False,
|
||||
) -> tuple[list["ChatCompletionMessageParam"], Optional[list[dict[str, Any]]]]:
|
||||
) -> tuple[list["ChatCompletionMessageParam"], list[dict[str, Any]] | None]:
|
||||
if add_generation_prompt and continue_final_message:
|
||||
raise ValueError(
|
||||
"Cannot set both `add_generation_prompt` and "
|
||||
@@ -144,7 +144,7 @@ def validate_request_params(request: "ChatCompletionRequest"):
|
||||
raise ValueError("chat_template is not supported for Mistral tokenizers.")
|
||||
|
||||
|
||||
def _tekken_token_to_id(tokenizer: "Tekkenizer", t: Union[str, bytes]) -> int:
|
||||
def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int:
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
||||
|
||||
assert isinstance(tokenizer, Tekkenizer), type(tokenizer)
|
||||
@@ -197,7 +197,7 @@ class MistralTokenizer(TokenizerBase):
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls, path_or_repo_id: str, *, revision: Optional[str] = None
|
||||
cls, path_or_repo_id: str, *, revision: str | None = None
|
||||
) -> "MistralTokenizer":
|
||||
from transformers.tokenization_mistral_common import (
|
||||
MistralCommonTokenizer as TransformersMistralTokenizer,
|
||||
@@ -298,11 +298,11 @@ class MistralTokenizer(TokenizerBase):
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: Union[str, list[str], list[int]],
|
||||
text_pair: Optional[str] = None,
|
||||
text: str | list[str] | list[int],
|
||||
text_pair: str | None = None,
|
||||
add_special_tokens: bool = False,
|
||||
truncation: bool = False,
|
||||
max_length: Optional[int] = None,
|
||||
max_length: int | None = None,
|
||||
):
|
||||
return self.transformers_tokenizer(
|
||||
text=text,
|
||||
@@ -327,7 +327,7 @@ class MistralTokenizer(TokenizerBase):
|
||||
self,
|
||||
text: str,
|
||||
truncation: bool = False,
|
||||
max_length: Optional[int] = None,
|
||||
max_length: int | None = None,
|
||||
) -> list[int]:
|
||||
# Mistral Tokenizers should not add special tokens
|
||||
return self.transformers_tokenizer.encode(
|
||||
@@ -337,9 +337,9 @@ class MistralTokenizer(TokenizerBase):
|
||||
def encode(
|
||||
self,
|
||||
text: str,
|
||||
truncation: Optional[bool] = None,
|
||||
max_length: Optional[int] = None,
|
||||
add_special_tokens: Optional[bool] = None,
|
||||
truncation: bool | None = None,
|
||||
max_length: int | None = None,
|
||||
add_special_tokens: bool | None = None,
|
||||
) -> list[int]:
|
||||
if add_special_tokens is not None:
|
||||
return self.transformers_tokenizer.encode(
|
||||
@@ -359,7 +359,7 @@ class MistralTokenizer(TokenizerBase):
|
||||
def apply_chat_template(
|
||||
self,
|
||||
messages: list["ChatCompletionMessageParam"],
|
||||
tools: Optional[list[dict[str, Any]]] = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
**kwargs,
|
||||
) -> list[int]:
|
||||
add_generation_prompt = kwargs.pop("add_generation_prompt", False)
|
||||
@@ -384,9 +384,7 @@ class MistralTokenizer(TokenizerBase):
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
def decode(
|
||||
self, ids: Union[list[int], int], skip_special_tokens: bool = True
|
||||
) -> str:
|
||||
def decode(self, ids: list[int] | int, skip_special_tokens: bool = True) -> str:
|
||||
return self.transformers_tokenizer.decode(
|
||||
ids, skip_special_tokens=skip_special_tokens
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user