Fix some Mistral parser issues (#37209)
Signed-off-by: juliendenize <julien.denize@mistral.ai>
This commit is contained in:
@@ -310,11 +310,14 @@ class OpenAIServingChat(OpenAIServing):
|
||||
trace_headers=trace_headers,
|
||||
)
|
||||
else:
|
||||
reasoning_ended = (
|
||||
reasoning_parser.is_reasoning_end(prompt_token_ids or [])
|
||||
if reasoning_parser
|
||||
else None
|
||||
)
|
||||
if not request.include_reasoning:
|
||||
reasoning_ended = True
|
||||
elif reasoning_parser:
|
||||
reasoning_ended = reasoning_parser.is_reasoning_end(
|
||||
prompt_token_ids or []
|
||||
)
|
||||
else:
|
||||
reasoning_ended = None
|
||||
|
||||
generator = self.engine_client.generate(
|
||||
engine_prompt,
|
||||
|
||||
@@ -15,8 +15,15 @@ from mistral_common.protocol.instruct.validator import ValidationMode
|
||||
from mistral_common.tokens.tokenizers.base import (
|
||||
SpecialTokenPolicy,
|
||||
SpecialTokens,
|
||||
Tokenizer,
|
||||
)
|
||||
from mistral_common.tokens.tokenizers.instruct import (
|
||||
InstructTokenizerBase,
|
||||
InstructTokenizerV13,
|
||||
)
|
||||
from mistral_common.tokens.tokenizers.mistral import (
|
||||
MistralTokenizer as MistralCommonTokenizer,
|
||||
)
|
||||
from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import (
|
||||
SentencePieceTokenizer,
|
||||
)
|
||||
@@ -26,21 +33,20 @@ from pydantic import ValidationError
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers.protocol import TokenizerLike
|
||||
|
||||
from .protocol import TokenizerLike
|
||||
try:
|
||||
# Transformers v5
|
||||
from transformers.tokenization_mistral_common import MistralCommonBackend
|
||||
except ImportError:
|
||||
# Transformers v4
|
||||
from transformers.tokenization_mistral_common import (
|
||||
MistralCommonTokenizer as MistralCommonBackend,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import BatchEncoding
|
||||
|
||||
try:
|
||||
# Transformers v5
|
||||
from transformers.tokenization_mistral_common import MistralCommonBackend
|
||||
except ImportError:
|
||||
# Transformers v4
|
||||
from transformers.tokenization_mistral_common import (
|
||||
MistralCommonTokenizer as MistralCommonBackend,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@@ -235,15 +241,6 @@ class MistralTokenizer(TokenizerLike):
|
||||
download_dir: str | None = None,
|
||||
**kwargs,
|
||||
) -> "MistralTokenizer":
|
||||
try:
|
||||
# Transformers v5
|
||||
from transformers.tokenization_mistral_common import MistralCommonBackend
|
||||
except ImportError:
|
||||
# Transformers v4
|
||||
from transformers.tokenization_mistral_common import (
|
||||
MistralCommonTokenizer as MistralCommonBackend,
|
||||
)
|
||||
|
||||
tokenizer = MistralCommonBackend.from_pretrained(
|
||||
path_or_repo_id,
|
||||
*args,
|
||||
@@ -255,13 +252,13 @@ class MistralTokenizer(TokenizerLike):
|
||||
|
||||
return cls(tokenizer)
|
||||
|
||||
def __init__(self, tokenizer: "MistralCommonBackend") -> None:
|
||||
def __init__(self, tokenizer: MistralCommonBackend) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.transformers_tokenizer = tokenizer
|
||||
self.mistral = tokenizer.tokenizer
|
||||
self.instruct = self.mistral.instruct_tokenizer
|
||||
self.tokenizer = self.instruct.tokenizer
|
||||
self.transformers_tokenizer: MistralCommonBackend = tokenizer
|
||||
self.mistral: MistralCommonTokenizer = tokenizer.tokenizer
|
||||
self.instruct: InstructTokenizerBase = self.mistral.instruct_tokenizer
|
||||
self.tokenizer: Tokenizer = self.instruct.tokenizer
|
||||
|
||||
mode = self.mistral._chat_completion_request_validator._mode
|
||||
if mode != ValidationMode.test:
|
||||
@@ -483,7 +480,11 @@ class MistralTokenizer(TokenizerLike):
|
||||
return self.transformers_tokenizer.convert_tokens_to_ids(tokens)
|
||||
|
||||
def convert_tokens_to_string(self, tokens: list[str]) -> str:
|
||||
to_decode_special_tokens = {SpecialTokens.tool_calls}
|
||||
to_decode_special_tokens = {
|
||||
SpecialTokens.tool_calls,
|
||||
SpecialTokens.begin_think,
|
||||
SpecialTokens.end_think,
|
||||
}
|
||||
if self.is_tekken:
|
||||
assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
|
||||
tokens = [
|
||||
|
||||
@@ -241,7 +241,10 @@ class MistralToolParser(ToolParser):
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> DeltaMessage | None:
|
||||
if self.bot_token_id not in current_token_ids:
|
||||
has_bot_token = (
|
||||
self.bot_token_id in current_token_ids or self.bot_token in current_text
|
||||
)
|
||||
if not has_bot_token:
|
||||
# if the tool call token is not in the tokens generated so far,
|
||||
# append output to contents since it's not a tool
|
||||
return DeltaMessage(content=delta_text)
|
||||
@@ -275,7 +278,8 @@ class MistralToolParser(ToolParser):
|
||||
additional_content: str = ""
|
||||
if self.streaming_state == StreamingState.WAITING_FOR_TOOL_START:
|
||||
# this is the first tool call
|
||||
assert self.bot_token_id in delta_token_ids
|
||||
if self.bot_token not in delta_text:
|
||||
return DeltaMessage(content=delta_text)
|
||||
if not delta_text.startswith(self.bot_token):
|
||||
additional_content += delta_text.split(self.bot_token)[0]
|
||||
delta_text = self.bot_token + "".join(
|
||||
@@ -411,7 +415,7 @@ class MistralToolParser(ToolParser):
|
||||
index=self.current_tool_id, type="function"
|
||||
)
|
||||
current_tool_call_modified = False
|
||||
if self.bot_token_id in delta_token_ids:
|
||||
if self.bot_token_id in delta_token_ids or self.bot_token in delta_text:
|
||||
# this is the first tool call
|
||||
if not delta_text.startswith(self.bot_token):
|
||||
content = delta_text.split(self.bot_token)[0]
|
||||
|
||||
Reference in New Issue
Block a user