Make sure mistral_common is not imported for non-mistral models (#12669)
When people use deepseek models, they find that they need to resolve a cv2 version conflict; see https://zhuanlan.zhihu.com/p/21064432691. I added the check, and made all imports of `cv2` lazy. --------- Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -8,21 +8,18 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
|
||||
|
||||
import huggingface_hub
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from mistral_common.tokens.tokenizers.base import SpecialTokens
|
||||
# yapf: disable
|
||||
from mistral_common.tokens.tokenizers.mistral import (
|
||||
MistralTokenizer as PublicMistralTokenizer)
|
||||
# yapf: enable
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import (
|
||||
SentencePieceTokenizer)
|
||||
from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
|
||||
Tekkenizer)
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# make sure `mistral_common` is lazily imported,
|
||||
# so that users who only use non-mistral models
|
||||
# will not be bothered by the dependency.
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from mistral_common.tokens.tokenizers.mistral import (
|
||||
MistralTokenizer as PublicMistralTokenizer)
|
||||
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -33,7 +30,7 @@ class Encoding:
|
||||
input_ids: Union[List[int], List[List[int]]]
|
||||
|
||||
|
||||
def maybe_serialize_tool_calls(request: ChatCompletionRequest):
|
||||
def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
|
||||
# SEE: https://github.com/vllm-project/vllm/pull/9951
|
||||
# Credits go to: @gcalmettes
|
||||
# NOTE: There is currently a bug in pydantic where attributes
|
||||
@@ -108,12 +105,16 @@ def find_tokenizer_file(files: List[str]):
|
||||
|
||||
class MistralTokenizer:
|
||||
|
||||
def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
|
||||
def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
|
||||
self.mistral = tokenizer
|
||||
self.instruct = tokenizer.instruct_tokenizer
|
||||
|
||||
tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
|
||||
from mistral_common.tokens.tokenizers.tekken import (
|
||||
SpecialTokenPolicy, Tekkenizer)
|
||||
self.is_tekken = isinstance(tokenizer_, Tekkenizer)
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import (
|
||||
SentencePieceTokenizer)
|
||||
self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer)
|
||||
if self.is_tekken:
|
||||
# Make sure special tokens will not raise
|
||||
@@ -153,6 +154,8 @@ class MistralTokenizer:
|
||||
assert Path(
|
||||
path_or_repo_id).is_file(), f"Invalid path: {path_or_repo_id}"
|
||||
|
||||
from mistral_common.tokens.tokenizers.mistral import (
|
||||
MistralTokenizer as PublicMistralTokenizer)
|
||||
mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file)
|
||||
return cls(mistral_tokenizer)
|
||||
|
||||
@@ -181,6 +184,8 @@ class MistralTokenizer:
|
||||
# by the guided structured output backends.
|
||||
@property
|
||||
def all_special_tokens_extended(self) -> List[str]:
|
||||
from mistral_common.tokens.tokenizers.base import SpecialTokens
|
||||
|
||||
# tekken defines its own extended special tokens list
|
||||
if hasattr(self.tokenizer, "SPECIAL_TOKENS"):
|
||||
special_tokens = self.tokenizer.SPECIAL_TOKENS
|
||||
@@ -284,6 +289,8 @@ class MistralTokenizer:
|
||||
if last_message["role"] == "assistant":
|
||||
last_message["prefix"] = True
|
||||
|
||||
from mistral_common.protocol.instruct.request import (
|
||||
ChatCompletionRequest)
|
||||
request = ChatCompletionRequest(messages=messages,
|
||||
tools=tools) # type: ignore[type-var]
|
||||
encoded = self.mistral.encode_chat_completion(request)
|
||||
@@ -292,6 +299,7 @@ class MistralTokenizer:
|
||||
return encoded.tokens
|
||||
|
||||
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
||||
from mistral_common.tokens.tokenizers.base import SpecialTokens
|
||||
if self.is_tekken:
|
||||
tokens = [
|
||||
t for t in tokens
|
||||
@@ -363,6 +371,8 @@ class MistralTokenizer:
|
||||
ids: List[int],
|
||||
skip_special_tokens: bool = True,
|
||||
) -> List[str]:
|
||||
from mistral_common.tokens.tokenizers.base import SpecialTokens
|
||||
|
||||
# TODO(Patrick) - potentially allow special tokens to not be skipped
|
||||
assert (
|
||||
skip_special_tokens
|
||||
|
||||
Reference in New Issue
Block a user