[Mistral common] Ensure all functions are imported from the top & only use public methods (#31138)
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
committed by
GitHub
parent
ce1eafd1a5
commit
48e744976c
@@ -31,7 +31,7 @@ partial-json-parser # used for parsing partial JSON outputs
|
|||||||
pyzmq >= 25.0.0
|
pyzmq >= 25.0.0
|
||||||
msgspec
|
msgspec
|
||||||
gguf >= 0.17.0
|
gguf >= 0.17.0
|
||||||
mistral_common[image] >= 1.8.5
|
mistral_common[image] >= 1.8.8
|
||||||
opencv-python-headless >= 4.11.0 # required for video IO
|
opencv-python-headless >= 4.11.0 # required for video IO
|
||||||
pyyaml
|
pyyaml
|
||||||
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
|
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ jiwer # required for audio tests
|
|||||||
timm # required for internvl test
|
timm # required for internvl test
|
||||||
transformers_stream_generator # required for qwen-vl test
|
transformers_stream_generator # required for qwen-vl test
|
||||||
matplotlib # required for qwen-vl test
|
matplotlib # required for qwen-vl test
|
||||||
mistral_common[image,audio] >= 1.8.5 # required for voxtral test
|
mistral_common[image,audio] >= 1.8.8 # required for voxtral test
|
||||||
num2words # required for smolvlm test
|
num2words # required for smolvlm test
|
||||||
opencv-python-headless >= 4.11.0 # required for video test
|
opencv-python-headless >= 4.11.0 # required for video test
|
||||||
datamodel_code_generator # required for minicpm3 test
|
datamodel_code_generator # required for minicpm3 test
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ torchaudio==2.9.1
|
|||||||
torchvision==0.24.1
|
torchvision==0.24.1
|
||||||
transformers_stream_generator # required for qwen-vl test
|
transformers_stream_generator # required for qwen-vl test
|
||||||
matplotlib # required for qwen-vl test
|
matplotlib # required for qwen-vl test
|
||||||
mistral_common[image,audio] >= 1.8.5 # required for voxtral test
|
mistral_common[image,audio] >= 1.8.8 # required for voxtral test
|
||||||
num2words # required for smolvlm test
|
num2words # required for smolvlm test
|
||||||
open_clip_torch==2.32.0 # Required for nemotron_vl test
|
open_clip_torch==2.32.0 # Required for nemotron_vl test
|
||||||
opencv-python-headless >= 4.11.0 # required for video test
|
opencv-python-headless >= 4.11.0 # required for video test
|
||||||
|
|||||||
@@ -482,7 +482,7 @@ mbstrdecoder==1.1.3
|
|||||||
# typepy
|
# typepy
|
||||||
mdurl==0.1.2
|
mdurl==0.1.2
|
||||||
# via markdown-it-py
|
# via markdown-it-py
|
||||||
mistral-common==1.8.5
|
mistral-common==1.8.8
|
||||||
# via -r requirements/test.in
|
# via -r requirements/test.in
|
||||||
mlflow==2.22.0
|
mlflow==2.22.0
|
||||||
# via terratorch
|
# via terratorch
|
||||||
|
|||||||
@@ -3,6 +3,21 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, cast
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
|
from mistral_common.protocol.instruct.request import (
|
||||||
|
ChatCompletionRequest as MistralChatCompletionRequest,
|
||||||
|
)
|
||||||
|
from mistral_common.protocol.instruct.tool_calls import Function, Tool
|
||||||
|
from mistral_common.protocol.instruct.validator import ValidationMode
|
||||||
|
from mistral_common.tokens.tokenizers.base import (
|
||||||
|
SpecialTokenPolicy,
|
||||||
|
SpecialTokens,
|
||||||
|
)
|
||||||
|
from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13
|
||||||
|
from mistral_common.tokens.tokenizers.sentencepiece import (
|
||||||
|
SentencePieceTokenizer,
|
||||||
|
)
|
||||||
|
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
||||||
|
|
||||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@@ -10,10 +25,6 @@ from vllm.logger import init_logger
|
|||||||
from .protocol import TokenizerLike
|
from .protocol import TokenizerLike
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from mistral_common.protocol.instruct.request import (
|
|
||||||
ChatCompletionRequest as MistralChatCompletionRequest,
|
|
||||||
)
|
|
||||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
|
||||||
from transformers import BatchEncoding
|
from transformers import BatchEncoding
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -101,8 +112,6 @@ def _prepare_apply_chat_template_tools_and_messages(
|
|||||||
continue_final_message: bool = False,
|
continue_final_message: bool = False,
|
||||||
add_generation_prompt: bool = False,
|
add_generation_prompt: bool = False,
|
||||||
) -> tuple[list["ChatCompletionMessageParam"], list[dict[str, Any]] | None]:
|
) -> tuple[list["ChatCompletionMessageParam"], list[dict[str, Any]] | None]:
|
||||||
from mistral_common.protocol.instruct.tool_calls import Function, Tool
|
|
||||||
|
|
||||||
if add_generation_prompt and continue_final_message:
|
if add_generation_prompt and continue_final_message:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Cannot set both `add_generation_prompt` and "
|
"Cannot set both `add_generation_prompt` and "
|
||||||
@@ -181,8 +190,6 @@ def validate_request_params(request: "ChatCompletionRequest"):
|
|||||||
|
|
||||||
|
|
||||||
def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int:
|
def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int:
|
||||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
|
||||||
|
|
||||||
assert isinstance(tokenizer, Tekkenizer), type(tokenizer)
|
assert isinstance(tokenizer, Tekkenizer), type(tokenizer)
|
||||||
|
|
||||||
t_bytes = t.encode("utf-8") if not isinstance(t, bytes) else t
|
t_bytes = t.encode("utf-8") if not isinstance(t, bytes) else t
|
||||||
@@ -210,8 +217,6 @@ class MistralTokenizer(TokenizerLike):
|
|||||||
download_dir: str | None = None,
|
download_dir: str | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> "MistralTokenizer":
|
) -> "MistralTokenizer":
|
||||||
from mistral_common.protocol.instruct.validator import ValidationMode
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Transformers v5
|
# Transformers v5
|
||||||
from transformers.tokenization_mistral_common import MistralCommonBackend
|
from transformers.tokenization_mistral_common import MistralCommonBackend
|
||||||
@@ -235,12 +240,6 @@ class MistralTokenizer(TokenizerLike):
|
|||||||
def __init__(self, tokenizer: "MistralCommonBackend") -> None:
|
def __init__(self, tokenizer: "MistralCommonBackend") -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
from mistral_common.protocol.instruct.validator import ValidationMode
|
|
||||||
from mistral_common.tokens.tokenizers.sentencepiece import (
|
|
||||||
SentencePieceTokenizer,
|
|
||||||
)
|
|
||||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
|
||||||
|
|
||||||
self.transformers_tokenizer = tokenizer
|
self.transformers_tokenizer = tokenizer
|
||||||
self.mistral = tokenizer.tokenizer
|
self.mistral = tokenizer.tokenizer
|
||||||
self.instruct = self.mistral.instruct_tokenizer
|
self.instruct = self.mistral.instruct_tokenizer
|
||||||
@@ -270,37 +269,20 @@ class MistralTokenizer(TokenizerLike):
|
|||||||
# Sort the dict for convenience
|
# Sort the dict for convenience
|
||||||
self._vocab_dict = dict(sorted(self._vocab_dict.items(), key=lambda x: x[1]))
|
self._vocab_dict = dict(sorted(self._vocab_dict.items(), key=lambda x: x[1]))
|
||||||
|
|
||||||
|
# Vocab sorted by token id.
|
||||||
|
self._vocab = self.tokenizer.vocab()
|
||||||
|
self._max_token_id = self.vocab_size - 1
|
||||||
|
|
||||||
# Cache special tokens for faster access.
|
# Cache special tokens for faster access.
|
||||||
self._special_token_ids = self._get_special_token_ids()
|
self._special_token_ids = self._get_special_token_ids()
|
||||||
self._special_token_ids_set = set(self._special_token_ids)
|
self._special_token_ids_set = set(self._special_token_ids)
|
||||||
self._special_tokens = self._get_special_tokens(self._special_token_ids)
|
self._special_tokens = self._get_special_tokens(self._special_token_ids)
|
||||||
self._special_tokens_set = set(self._special_tokens)
|
self._special_tokens_set = set(self._special_tokens)
|
||||||
|
|
||||||
# Vocab sorted by token id.
|
|
||||||
self._vocab = self.tokenizer._vocab
|
|
||||||
self._max_token_id = self.vocab_size - 1
|
|
||||||
|
|
||||||
def _get_special_token_ids(self) -> list[int]:
|
def _get_special_token_ids(self) -> list[int]:
|
||||||
from mistral_common.tokens.tokenizers.sentencepiece import (
|
return [i for i in range(len(self._vocab)) if self.tokenizer.is_special(i)]
|
||||||
SentencePieceTokenizer,
|
|
||||||
)
|
|
||||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
|
||||||
|
|
||||||
if self.is_tekken:
|
|
||||||
assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
|
|
||||||
special_ids = {t["rank"] for t in self.tokenizer._all_special_tokens}
|
|
||||||
elif self.is_spm:
|
|
||||||
assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
|
|
||||||
self.tokenizer
|
|
||||||
)
|
|
||||||
special_ids = self.tokenizer._control_tokens
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
|
|
||||||
return sorted(special_ids)
|
|
||||||
|
|
||||||
def _get_special_tokens(self, all_special_ids: list[int]) -> list[str]:
|
def _get_special_tokens(self, all_special_ids: list[int]) -> list[str]:
|
||||||
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
|
|
||||||
|
|
||||||
return [
|
return [
|
||||||
self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP)
|
self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP)
|
||||||
for i in all_special_ids
|
for i in all_special_ids
|
||||||
@@ -460,15 +442,6 @@ class MistralTokenizer(TokenizerLike):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def convert_tokens_to_string(self, tokens: list[str]) -> str:
|
def convert_tokens_to_string(self, tokens: list[str]) -> str:
|
||||||
from mistral_common.tokens.tokenizers.base import (
|
|
||||||
SpecialTokenPolicy,
|
|
||||||
SpecialTokens,
|
|
||||||
)
|
|
||||||
from mistral_common.tokens.tokenizers.sentencepiece import (
|
|
||||||
SentencePieceTokenizer,
|
|
||||||
)
|
|
||||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
|
||||||
|
|
||||||
to_decode_special_tokens = {SpecialTokens.tool_calls}
|
to_decode_special_tokens = {SpecialTokens.tool_calls}
|
||||||
if self.is_tekken:
|
if self.is_tekken:
|
||||||
assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
|
assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
|
||||||
@@ -523,12 +496,6 @@ class MistralTokenizer(TokenizerLike):
|
|||||||
ids: list[int],
|
ids: list[int],
|
||||||
skip_special_tokens: bool = False,
|
skip_special_tokens: bool = False,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
from mistral_common.tokens.tokenizers.base import (
|
|
||||||
SpecialTokenPolicy,
|
|
||||||
SpecialTokens,
|
|
||||||
)
|
|
||||||
from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13
|
|
||||||
|
|
||||||
if not skip_special_tokens:
|
if not skip_special_tokens:
|
||||||
return [self.tokenizer.id_to_piece(token_id) for token_id in ids]
|
return [self.tokenizer.id_to_piece(token_id) for token_id in ids]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user