[Mistral common] Ensure all functions are imported from the top & only use public methods (#31138)

Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Patrick von Platen
2025-12-26 13:48:24 +01:00
committed by GitHub
parent ce1eafd1a5
commit 48e744976c
5 changed files with 24 additions and 57 deletions

View File

@@ -3,6 +3,21 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, cast
from mistral_common.protocol.instruct.request import (
ChatCompletionRequest as MistralChatCompletionRequest,
)
from mistral_common.protocol.instruct.tool_calls import Function, Tool
from mistral_common.protocol.instruct.validator import ValidationMode
from mistral_common.tokens.tokenizers.base import (
SpecialTokenPolicy,
SpecialTokens,
)
from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer,
)
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.logger import init_logger
@@ -10,10 +25,6 @@ from vllm.logger import init_logger
from .protocol import TokenizerLike
if TYPE_CHECKING:
from mistral_common.protocol.instruct.request import (
ChatCompletionRequest as MistralChatCompletionRequest,
)
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
from transformers import BatchEncoding
try:
@@ -101,8 +112,6 @@ def _prepare_apply_chat_template_tools_and_messages(
continue_final_message: bool = False,
add_generation_prompt: bool = False,
) -> tuple[list["ChatCompletionMessageParam"], list[dict[str, Any]] | None]:
from mistral_common.protocol.instruct.tool_calls import Function, Tool
if add_generation_prompt and continue_final_message:
raise ValueError(
"Cannot set both `add_generation_prompt` and "
@@ -181,8 +190,6 @@ def validate_request_params(request: "ChatCompletionRequest"):
def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int:
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
assert isinstance(tokenizer, Tekkenizer), type(tokenizer)
t_bytes = t.encode("utf-8") if not isinstance(t, bytes) else t
@@ -210,8 +217,6 @@ class MistralTokenizer(TokenizerLike):
download_dir: str | None = None,
**kwargs,
) -> "MistralTokenizer":
from mistral_common.protocol.instruct.validator import ValidationMode
try:
# Transformers v5
from transformers.tokenization_mistral_common import MistralCommonBackend
@@ -235,12 +240,6 @@ class MistralTokenizer(TokenizerLike):
def __init__(self, tokenizer: "MistralCommonBackend") -> None:
super().__init__()
from mistral_common.protocol.instruct.validator import ValidationMode
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer,
)
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
self.transformers_tokenizer = tokenizer
self.mistral = tokenizer.tokenizer
self.instruct = self.mistral.instruct_tokenizer
@@ -270,37 +269,20 @@ class MistralTokenizer(TokenizerLike):
# Sort the dict for convenience
self._vocab_dict = dict(sorted(self._vocab_dict.items(), key=lambda x: x[1]))
# Vocab sorted by token id.
self._vocab = self.tokenizer.vocab()
self._max_token_id = self.vocab_size - 1
# Cache special tokens for faster access.
self._special_token_ids = self._get_special_token_ids()
self._special_token_ids_set = set(self._special_token_ids)
self._special_tokens = self._get_special_tokens(self._special_token_ids)
self._special_tokens_set = set(self._special_tokens)
# Vocab sorted by token id.
self._vocab = self.tokenizer._vocab
self._max_token_id = self.vocab_size - 1
def _get_special_token_ids(self) -> list[int]:
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer,
)
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
if self.is_tekken:
assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
special_ids = {t["rank"] for t in self.tokenizer._all_special_tokens}
elif self.is_spm:
assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
self.tokenizer
)
special_ids = self.tokenizer._control_tokens
else:
raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
return sorted(special_ids)
return [i for i in range(len(self._vocab)) if self.tokenizer.is_special(i)]
def _get_special_tokens(self, all_special_ids: list[int]) -> list[str]:
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
return [
self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP)
for i in all_special_ids
@@ -460,15 +442,6 @@ class MistralTokenizer(TokenizerLike):
)
def convert_tokens_to_string(self, tokens: list[str]) -> str:
from mistral_common.tokens.tokenizers.base import (
SpecialTokenPolicy,
SpecialTokens,
)
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer,
)
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
to_decode_special_tokens = {SpecialTokens.tool_calls}
if self.is_tekken:
assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
@@ -523,12 +496,6 @@ class MistralTokenizer(TokenizerLike):
ids: list[int],
skip_special_tokens: bool = False,
) -> list[str]:
from mistral_common.tokens.tokenizers.base import (
SpecialTokenPolicy,
SpecialTokens,
)
from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13
if not skip_special_tokens:
return [self.tokenizer.id_to_piece(token_id) for token_id in ids]