[Mistral common] Ensure all functions are imported from the top & only use public methods (#31138)

Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Author: Patrick von Platen
Date: 2025-12-26 13:48:24 +01:00
Committed by: GitHub
Parent: ce1eafd1a5
Commit: 48e744976c
5 changed files with 24 additions and 57 deletions


@@ -31,7 +31,7 @@ partial-json-parser # used for parsing partial JSON outputs
 pyzmq >= 25.0.0
 msgspec
 gguf >= 0.17.0
-mistral_common[image] >= 1.8.5
+mistral_common[image] >= 1.8.8
 opencv-python-headless >= 4.11.0 # required for video IO
 pyyaml
 six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12


@@ -23,7 +23,7 @@ jiwer # required for audio tests
 timm # required for internvl test
 transformers_stream_generator # required for qwen-vl test
 matplotlib # required for qwen-vl test
-mistral_common[image,audio] >= 1.8.5 # required for voxtral test
+mistral_common[image,audio] >= 1.8.8 # required for voxtral test
 num2words # required for smolvlm test
 opencv-python-headless >= 4.11.0 # required for video test
 datamodel_code_generator # required for minicpm3 test


@@ -29,7 +29,7 @@ torchaudio==2.9.1
 torchvision==0.24.1
 transformers_stream_generator # required for qwen-vl test
 matplotlib # required for qwen-vl test
-mistral_common[image,audio] >= 1.8.5 # required for voxtral test
+mistral_common[image,audio] >= 1.8.8 # required for voxtral test
 num2words # required for smolvlm test
 open_clip_torch==2.32.0 # Required for nemotron_vl test
 opencv-python-headless >= 4.11.0 # required for video test


@@ -482,7 +482,7 @@ mbstrdecoder==1.1.3
     #   typepy
 mdurl==0.1.2
     # via markdown-it-py
-mistral-common==1.8.5
+mistral-common==1.8.8
     # via -r requirements/test.in
 mlflow==2.22.0
     # via terratorch
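All four requirements files above raise the minimum mistral_common from 1.8.5 to 1.8.8, the release expected to expose the public tokenizer helpers the refactored wrapper below relies on. A minimal sanity check for a local environment (a sketch, not part of this change; it assumes the packaging library is available, and that the PyPI distribution is named mistral-common while the import package is mistral_common):

# Check that the installed mistral_common satisfies the new minimum version.
from importlib.metadata import version
from packaging.version import Version

installed = Version(version("mistral-common"))
assert installed >= Version("1.8.8"), f"mistral_common {installed} is older than 1.8.8"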


@@ -3,6 +3,21 @@
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, cast

+from mistral_common.protocol.instruct.request import (
+    ChatCompletionRequest as MistralChatCompletionRequest,
+)
+from mistral_common.protocol.instruct.tool_calls import Function, Tool
+from mistral_common.protocol.instruct.validator import ValidationMode
+from mistral_common.tokens.tokenizers.base import (
+    SpecialTokenPolicy,
+    SpecialTokens,
+)
+from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13
+from mistral_common.tokens.tokenizers.sentencepiece import (
+    SentencePieceTokenizer,
+)
+from mistral_common.tokens.tokenizers.tekken import Tekkenizer
+
 from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
 from vllm.logger import init_logger
@@ -10,10 +25,6 @@ from vllm.logger import init_logger
 from .protocol import TokenizerLike

 if TYPE_CHECKING:
-    from mistral_common.protocol.instruct.request import (
-        ChatCompletionRequest as MistralChatCompletionRequest,
-    )
-    from mistral_common.tokens.tokenizers.tekken import Tekkenizer
     from transformers import BatchEncoding

 try:
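The two hunks above hoist every mistral_common import to module scope and drop the copies that previously lived only under TYPE_CHECKING. A minimal sketch of the pattern being applied (build_tool is a hypothetical helper used purely for illustration, not a function from this file):

# Before: each call site paid for, and hid, its own deferred import; the string
# annotation only resolved for type checkers via a TYPE_CHECKING duplicate.
def build_tool(schema: dict) -> "Tool":
    from mistral_common.protocol.instruct.tool_calls import Function, Tool
    return Tool(function=Function(**schema))

# After: one module-level import that both the runtime and type checkers see.
from mistral_common.protocol.instruct.tool_calls import Function, Tool

def build_tool(schema: dict) -> Tool:
    return Tool(function=Function(**schema))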
@@ -101,8 +112,6 @@ def _prepare_apply_chat_template_tools_and_messages(
     continue_final_message: bool = False,
     add_generation_prompt: bool = False,
 ) -> tuple[list["ChatCompletionMessageParam"], list[dict[str, Any]] | None]:
-    from mistral_common.protocol.instruct.tool_calls import Function, Tool
-
     if add_generation_prompt and continue_final_message:
         raise ValueError(
             "Cannot set both `add_generation_prompt` and "
@@ -181,8 +190,6 @@ def validate_request_params(request: "ChatCompletionRequest"):


 def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int:
-    from mistral_common.tokens.tokenizers.tekken import Tekkenizer
-
     assert isinstance(tokenizer, Tekkenizer), type(tokenizer)

     t_bytes = t.encode("utf-8") if not isinstance(t, bytes) else t
@@ -210,8 +217,6 @@ class MistralTokenizer(TokenizerLike):
         download_dir: str | None = None,
         **kwargs,
     ) -> "MistralTokenizer":
-        from mistral_common.protocol.instruct.validator import ValidationMode
-
         try:
             # Transformers v5
             from transformers.tokenization_mistral_common import MistralCommonBackend
@@ -235,12 +240,6 @@ class MistralTokenizer(TokenizerLike):

     def __init__(self, tokenizer: "MistralCommonBackend") -> None:
         super().__init__()
-        from mistral_common.protocol.instruct.validator import ValidationMode
-        from mistral_common.tokens.tokenizers.sentencepiece import (
-            SentencePieceTokenizer,
-        )
-        from mistral_common.tokens.tokenizers.tekken import Tekkenizer
-
         self.transformers_tokenizer = tokenizer
         self.mistral = tokenizer.tokenizer
         self.instruct = self.mistral.instruct_tokenizer
@@ -270,37 +269,20 @@ class MistralTokenizer(TokenizerLike):
         # Sort the dict for convenience
         self._vocab_dict = dict(sorted(self._vocab_dict.items(), key=lambda x: x[1]))

+        # Vocab sorted by token id.
+        self._vocab = self.tokenizer.vocab()
+        self._max_token_id = self.vocab_size - 1
+
         # Cache special tokens for faster access.
         self._special_token_ids = self._get_special_token_ids()
         self._special_token_ids_set = set(self._special_token_ids)
         self._special_tokens = self._get_special_tokens(self._special_token_ids)
         self._special_tokens_set = set(self._special_tokens)

-        # Vocab sorted by token id.
-        self._vocab = self.tokenizer._vocab
-        self._max_token_id = self.vocab_size - 1
-
     def _get_special_token_ids(self) -> list[int]:
-        from mistral_common.tokens.tokenizers.sentencepiece import (
-            SentencePieceTokenizer,
-        )
-        from mistral_common.tokens.tokenizers.tekken import Tekkenizer
-
-        if self.is_tekken:
-            assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
-            special_ids = {t["rank"] for t in self.tokenizer._all_special_tokens}
-        elif self.is_spm:
-            assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
-                self.tokenizer
-            )
-            special_ids = self.tokenizer._control_tokens
-        else:
-            raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
-        return sorted(special_ids)
+        return [i for i in range(len(self._vocab)) if self.tokenizer.is_special(i)]

     def _get_special_tokens(self, all_special_ids: list[int]) -> list[str]:
-        from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
-
         return [
             self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP)
             for i in all_special_ids
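This hunk carries the "only use public methods" half of the change: instead of branching on the tokenizer type and reading Tekkenizer._all_special_tokens or SentencePieceTokenizer._control_tokens, the wrapper now asks the base tokenizer itself through vocab() and is_special(). A standalone sketch of the same pattern (assumes mistral_common >= 1.8.8 as required above; MistralTokenizer.v3() is used only to obtain a concrete tokenizer to inspect):

# Enumerate special token ids through the public surface, with no type
# branching and no private attribute access.
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

base = MistralTokenizer.v3().instruct_tokenizer.tokenizer
vocab = base.vocab()  # token strings ordered by token id
special_ids = [i for i in range(len(vocab)) if base.is_special(i)]
print(f"{len(special_ids)} special tokens out of {len(vocab)}")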
@@ -460,15 +442,6 @@ class MistralTokenizer(TokenizerLike):
         )

     def convert_tokens_to_string(self, tokens: list[str]) -> str:
-        from mistral_common.tokens.tokenizers.base import (
-            SpecialTokenPolicy,
-            SpecialTokens,
-        )
-        from mistral_common.tokens.tokenizers.sentencepiece import (
-            SentencePieceTokenizer,
-        )
-        from mistral_common.tokens.tokenizers.tekken import Tekkenizer
-
         to_decode_special_tokens = {SpecialTokens.tool_calls}
         if self.is_tekken:
             assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
@@ -523,12 +496,6 @@ class MistralTokenizer(TokenizerLike):
         ids: list[int],
         skip_special_tokens: bool = False,
     ) -> list[str]:
-        from mistral_common.tokens.tokenizers.base import (
-            SpecialTokenPolicy,
-            SpecialTokens,
-        )
-        from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13
-
         if not skip_special_tokens:
             return [self.tokenizer.id_to_piece(token_id) for token_id in ids]
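The last hunks remove the remaining function-local imports; token-to-text conversion keeps going through the public decode() and id_to_piece() calls visible above. A small usage sketch under the same assumptions (bos_id is assumed to be a public attribute of the base tokenizer):

# Decode a special token while keeping it visible in the output text.
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

base = MistralTokenizer.v3().instruct_tokenizer.tokenizer
bos = base.bos_id
print(base.id_to_piece(bos))
print(base.decode([bos], special_token_policy=SpecialTokenPolicy.KEEP))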