make sure mistral_common not imported for non-mistral models (#12669)

When people use deepseek models, they find that they need to solve cv2
version conflict, see https://zhuanlan.zhihu.com/p/21064432691 .

I added the check, and make all imports of `cv2` lazy.

---------

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2025-02-03 13:40:25 +08:00
committed by GitHub
parent 95460fc513
commit 20579c0fae
4 changed files with 40 additions and 20 deletions

View File

@@ -8,21 +8,18 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
import huggingface_hub
from huggingface_hub import HfApi, hf_hub_download
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.base import SpecialTokens
# yapf: disable
from mistral_common.tokens.tokenizers.mistral import (
MistralTokenizer as PublicMistralTokenizer)
# yapf: enable
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer)
from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
Tekkenizer)
from vllm.logger import init_logger
from vllm.utils import is_list_of
if TYPE_CHECKING:
# make sure `mistral_common` is lazy imported,
# so that users who only use non-mistral models
# will not be bothered by the dependency.
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import (
MistralTokenizer as PublicMistralTokenizer)
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
logger = init_logger(__name__)
@@ -33,7 +30,7 @@ class Encoding:
input_ids: Union[List[int], List[List[int]]]
def maybe_serialize_tool_calls(request: ChatCompletionRequest):
def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
# SEE: https://github.com/vllm-project/vllm/pull/9951
# Credits go to: @gcalmettes
# NOTE: There is currently a bug in pydantic where attributes
@@ -108,12 +105,16 @@ def find_tokenizer_file(files: List[str]):
class MistralTokenizer:
def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
self.mistral = tokenizer
self.instruct = tokenizer.instruct_tokenizer
tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
from mistral_common.tokens.tokenizers.tekken import (
SpecialTokenPolicy, Tekkenizer)
self.is_tekken = isinstance(tokenizer_, Tekkenizer)
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer)
self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer)
if self.is_tekken:
# Make sure special tokens will not raise
@@ -153,6 +154,8 @@ class MistralTokenizer:
assert Path(
path_or_repo_id).is_file(), f"Invalid path: {path_or_repo_id}"
from mistral_common.tokens.tokenizers.mistral import (
MistralTokenizer as PublicMistralTokenizer)
mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file)
return cls(mistral_tokenizer)
@@ -181,6 +184,8 @@ class MistralTokenizer:
# by the guided structured output backends.
@property
def all_special_tokens_extended(self) -> List[str]:
from mistral_common.tokens.tokenizers.base import SpecialTokens
# tekken defines its own extended special tokens list
if hasattr(self.tokenizer, "SPECIAL_TOKENS"):
special_tokens = self.tokenizer.SPECIAL_TOKENS
@@ -284,6 +289,8 @@ class MistralTokenizer:
if last_message["role"] == "assistant":
last_message["prefix"] = True
from mistral_common.protocol.instruct.request import (
ChatCompletionRequest)
request = ChatCompletionRequest(messages=messages,
tools=tools) # type: ignore[type-var]
encoded = self.mistral.encode_chat_completion(request)
@@ -292,6 +299,7 @@ class MistralTokenizer:
return encoded.tokens
def convert_tokens_to_string(self, tokens: List[str]) -> str:
from mistral_common.tokens.tokenizers.base import SpecialTokens
if self.is_tekken:
tokens = [
t for t in tokens
@@ -363,6 +371,8 @@ class MistralTokenizer:
ids: List[int],
skip_special_tokens: bool = True,
) -> List[str]:
from mistral_common.tokens.tokenizers.base import SpecialTokens
# TODO(Patrick) - potentially allow special tokens to not be skipped
assert (
skip_special_tokens