[Mistral common] Ensure all functions are imported from the top & only use public methods (#31138)
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Julien Denize <40604584+juliendenize@users.noreply.github.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
committed by
GitHub
parent
ce1eafd1a5
commit
48e744976c
@@ -3,6 +3,21 @@
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from mistral_common.protocol.instruct.request import (
|
||||
ChatCompletionRequest as MistralChatCompletionRequest,
|
||||
)
|
||||
from mistral_common.protocol.instruct.tool_calls import Function, Tool
|
||||
from mistral_common.protocol.instruct.validator import ValidationMode
|
||||
from mistral_common.tokens.tokenizers.base import (
|
||||
SpecialTokenPolicy,
|
||||
SpecialTokens,
|
||||
)
|
||||
from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import (
|
||||
SentencePieceTokenizer,
|
||||
)
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
||||
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from vllm.logger import init_logger
|
||||
@@ -10,10 +25,6 @@ from vllm.logger import init_logger
|
||||
from .protocol import TokenizerLike
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mistral_common.protocol.instruct.request import (
|
||||
ChatCompletionRequest as MistralChatCompletionRequest,
|
||||
)
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
||||
from transformers import BatchEncoding
|
||||
|
||||
try:
|
||||
@@ -101,8 +112,6 @@ def _prepare_apply_chat_template_tools_and_messages(
|
||||
continue_final_message: bool = False,
|
||||
add_generation_prompt: bool = False,
|
||||
) -> tuple[list["ChatCompletionMessageParam"], list[dict[str, Any]] | None]:
|
||||
from mistral_common.protocol.instruct.tool_calls import Function, Tool
|
||||
|
||||
if add_generation_prompt and continue_final_message:
|
||||
raise ValueError(
|
||||
"Cannot set both `add_generation_prompt` and "
|
||||
@@ -181,8 +190,6 @@ def validate_request_params(request: "ChatCompletionRequest"):
|
||||
|
||||
|
||||
def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int:
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
||||
|
||||
assert isinstance(tokenizer, Tekkenizer), type(tokenizer)
|
||||
|
||||
t_bytes = t.encode("utf-8") if not isinstance(t, bytes) else t
|
||||
@@ -210,8 +217,6 @@ class MistralTokenizer(TokenizerLike):
|
||||
download_dir: str | None = None,
|
||||
**kwargs,
|
||||
) -> "MistralTokenizer":
|
||||
from mistral_common.protocol.instruct.validator import ValidationMode
|
||||
|
||||
try:
|
||||
# Transformers v5
|
||||
from transformers.tokenization_mistral_common import MistralCommonBackend
|
||||
@@ -235,12 +240,6 @@ class MistralTokenizer(TokenizerLike):
|
||||
def __init__(self, tokenizer: "MistralCommonBackend") -> None:
|
||||
super().__init__()
|
||||
|
||||
from mistral_common.protocol.instruct.validator import ValidationMode
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import (
|
||||
SentencePieceTokenizer,
|
||||
)
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
||||
|
||||
self.transformers_tokenizer = tokenizer
|
||||
self.mistral = tokenizer.tokenizer
|
||||
self.instruct = self.mistral.instruct_tokenizer
|
||||
@@ -270,37 +269,20 @@ class MistralTokenizer(TokenizerLike):
|
||||
# Sort the dict for convenience
|
||||
self._vocab_dict = dict(sorted(self._vocab_dict.items(), key=lambda x: x[1]))
|
||||
|
||||
# Vocab sorted by token id.
|
||||
self._vocab = self.tokenizer.vocab()
|
||||
self._max_token_id = self.vocab_size - 1
|
||||
|
||||
# Cache special tokens for faster access.
|
||||
self._special_token_ids = self._get_special_token_ids()
|
||||
self._special_token_ids_set = set(self._special_token_ids)
|
||||
self._special_tokens = self._get_special_tokens(self._special_token_ids)
|
||||
self._special_tokens_set = set(self._special_tokens)
|
||||
|
||||
# Vocab sorted by token id.
|
||||
self._vocab = self.tokenizer._vocab
|
||||
self._max_token_id = self.vocab_size - 1
|
||||
|
||||
def _get_special_token_ids(self) -> list[int]:
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import (
|
||||
SentencePieceTokenizer,
|
||||
)
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
||||
|
||||
if self.is_tekken:
|
||||
assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
|
||||
special_ids = {t["rank"] for t in self.tokenizer._all_special_tokens}
|
||||
elif self.is_spm:
|
||||
assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
|
||||
self.tokenizer
|
||||
)
|
||||
special_ids = self.tokenizer._control_tokens
|
||||
else:
|
||||
raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
|
||||
return sorted(special_ids)
|
||||
return [i for i in range(len(self._vocab)) if self.tokenizer.is_special(i)]
|
||||
|
||||
def _get_special_tokens(self, all_special_ids: list[int]) -> list[str]:
|
||||
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
|
||||
|
||||
return [
|
||||
self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP)
|
||||
for i in all_special_ids
|
||||
@@ -460,15 +442,6 @@ class MistralTokenizer(TokenizerLike):
|
||||
)
|
||||
|
||||
def convert_tokens_to_string(self, tokens: list[str]) -> str:
|
||||
from mistral_common.tokens.tokenizers.base import (
|
||||
SpecialTokenPolicy,
|
||||
SpecialTokens,
|
||||
)
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import (
|
||||
SentencePieceTokenizer,
|
||||
)
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
||||
|
||||
to_decode_special_tokens = {SpecialTokens.tool_calls}
|
||||
if self.is_tekken:
|
||||
assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
|
||||
@@ -523,12 +496,6 @@ class MistralTokenizer(TokenizerLike):
|
||||
ids: list[int],
|
||||
skip_special_tokens: bool = False,
|
||||
) -> list[str]:
|
||||
from mistral_common.tokens.tokenizers.base import (
|
||||
SpecialTokenPolicy,
|
||||
SpecialTokens,
|
||||
)
|
||||
from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13
|
||||
|
||||
if not skip_special_tokens:
|
||||
return [self.tokenizer.id_to_piece(token_id) for token_id in ids]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user