Fix some Mistral parser issues (#37209)

Signed-off-by: juliendenize <julien.denize@mistral.ai>
Author: Julien Denize
Date: 2026-03-17 01:08:56 +01:00
Committed by: GitHub
Parent: 061980c36a
Commit: 5db91f0aaf

3 changed files with 42 additions and 34 deletions

File 1 of 3:

@@ -310,11 +310,14 @@ class OpenAIServingChat(OpenAIServing):
                         trace_headers=trace_headers,
                     )
                 else:
-                    reasoning_ended = (
-                        reasoning_parser.is_reasoning_end(prompt_token_ids or [])
-                        if reasoning_parser
-                        else None
-                    )
+                    if not request.include_reasoning:
+                        reasoning_ended = True
+                    elif reasoning_parser:
+                        reasoning_ended = reasoning_parser.is_reasoning_end(
+                            prompt_token_ids or []
+                        )
+                    else:
+                        reasoning_ended = None
                     generator = self.engine_client.generate(
                         engine_prompt,
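
Note: the old expression consulted only the reasoning parser, so reasoning could be treated as still open even when the client had disabled it. The new branch order makes include_reasoning=False force reasoning_ended=True. A minimal sketch of that decision order, where RequestStub and ParserStub are hypothetical stand-ins for the real request and reasoning-parser objects:

from dataclasses import dataclass

@dataclass
class RequestStub:
    include_reasoning: bool = True

class ParserStub:
    def is_reasoning_end(self, token_ids: list[int]) -> bool:
        # A real reasoning parser checks for its end-of-think token id;
        # token id 2 is a made-up stand-in here.
        return 2 in token_ids

def resolve_reasoning_ended(
    request: RequestStub,
    reasoning_parser: ParserStub | None,
    prompt_token_ids: list[int] | None,
) -> bool | None:
    # Mirrors the new branch order in the diff above.
    if not request.include_reasoning:
        return True  # client opted out of reasoning: treat it as ended
    if reasoning_parser:
        return reasoning_parser.is_reasoning_end(prompt_token_ids or [])
    return None

assert resolve_reasoning_ended(RequestStub(include_reasoning=False), None, None) is True
assert resolve_reasoning_ended(RequestStub(), ParserStub(), [1, 2]) is True
assert resolve_reasoning_ended(RequestStub(), None, [1]) is None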

File 2 of 3:

@@ -15,8 +15,15 @@ from mistral_common.protocol.instruct.validator import ValidationMode
 from mistral_common.tokens.tokenizers.base import (
     SpecialTokenPolicy,
     SpecialTokens,
+    Tokenizer,
 )
-from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13
+from mistral_common.tokens.tokenizers.instruct import (
+    InstructTokenizerBase,
+    InstructTokenizerV13,
+)
+from mistral_common.tokens.tokenizers.mistral import (
+    MistralTokenizer as MistralCommonTokenizer,
+)
 from mistral_common.tokens.tokenizers.sentencepiece import (
     SentencePieceTokenizer,
 )
@@ -26,21 +33,20 @@ from pydantic import ValidationError
 from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
 from vllm.logger import init_logger
-from vllm.tokenizers.protocol import TokenizerLike
+from .protocol import TokenizerLike
+
+try:
+    # Transformers v5
+    from transformers.tokenization_mistral_common import MistralCommonBackend
+except ImportError:
+    # Transformers v4
+    from transformers.tokenization_mistral_common import (
+        MistralCommonTokenizer as MistralCommonBackend,
+    )
 
 if TYPE_CHECKING:
     from transformers import BatchEncoding
-
-    try:
-        # Transformers v5
-        from transformers.tokenization_mistral_common import MistralCommonBackend
-    except ImportError:
-        # Transformers v4
-        from transformers.tokenization_mistral_common import (
-            MistralCommonTokenizer as MistralCommonBackend,
-        )
 
 logger = init_logger(__name__)
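
Note: hoisting the Transformers v4/v5 try/except to module scope binds MistralCommonBackend at import time, so it can be used directly in annotations (see the __init__ change below) and no longer needs to be re-imported inside from_pretrained. A runnable sketch of the same fallback pattern, shown with stdlib modules rather than the Transformers ones (tomllib is the Python 3.11+ name, tomli the third-party backport with the same API):

try:
    import tomllib  # newer layout: stdlib module on Python >= 3.11
except ImportError:
    import tomli as tomllib  # older layout: backport aliased to one name

# Callers use a single name regardless of which import succeeded.
print(tomllib.loads("answer = 42")["answer"])  # -> 42
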
@@ -235,15 +241,6 @@ class MistralTokenizer(TokenizerLike):
         download_dir: str | None = None,
         **kwargs,
     ) -> "MistralTokenizer":
-        try:
-            # Transformers v5
-            from transformers.tokenization_mistral_common import MistralCommonBackend
-        except ImportError:
-            # Transformers v4
-            from transformers.tokenization_mistral_common import (
-                MistralCommonTokenizer as MistralCommonBackend,
-            )
-
         tokenizer = MistralCommonBackend.from_pretrained(
             path_or_repo_id,
             *args,
@@ -255,13 +252,13 @@ class MistralTokenizer(TokenizerLike):
         return cls(tokenizer)
 
-    def __init__(self, tokenizer: "MistralCommonBackend") -> None:
+    def __init__(self, tokenizer: MistralCommonBackend) -> None:
         super().__init__()
-        self.transformers_tokenizer = tokenizer
-        self.mistral = tokenizer.tokenizer
-        self.instruct = self.mistral.instruct_tokenizer
-        self.tokenizer = self.instruct.tokenizer
+        self.transformers_tokenizer: MistralCommonBackend = tokenizer
+        self.mistral: MistralCommonTokenizer = tokenizer.tokenizer
+        self.instruct: InstructTokenizerBase = self.mistral.instruct_tokenizer
+        self.tokenizer: Tokenizer = self.instruct.tokenizer
 
         mode = self.mistral._chat_completion_request_validator._mode
         if mode != ValidationMode.test:
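
Note: the new annotations spell out the wrapper chain that was previously implicit: the Transformers backend wraps mistral-common's MistralTokenizer, which exposes an instruct tokenizer, which in turn holds the raw Tokenizer. A sketch of that delegation layering, with hypothetical stand-in classes rather than the real APIs:

class RawTokenizer:  # stands in for mistral_common's Tokenizer
    def decode(self, ids: list[int]) -> str:
        return "".join(chr(i) for i in ids)

class InstructTokenizer:  # stands in for InstructTokenizerBase
    def __init__(self) -> None:
        self.tokenizer = RawTokenizer()

class MistralCommon:  # stands in for MistralCommonTokenizer
    def __init__(self) -> None:
        self.instruct_tokenizer = InstructTokenizer()

class Backend:  # stands in for MistralCommonBackend
    def __init__(self) -> None:
        self.tokenizer = MistralCommon()

backend = Backend()
raw = backend.tokenizer.instruct_tokenizer.tokenizer  # the unwrap chain
print(raw.decode([104, 105]))  # -> "hi"
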
@@ -483,7 +480,11 @@ class MistralTokenizer(TokenizerLike):
         return self.transformers_tokenizer.convert_tokens_to_ids(tokens)
 
     def convert_tokens_to_string(self, tokens: list[str]) -> str:
-        to_decode_special_tokens = {SpecialTokens.tool_calls}
+        to_decode_special_tokens = {
+            SpecialTokens.tool_calls,
+            SpecialTokens.begin_think,
+            SpecialTokens.end_think,
+        }
         if self.is_tekken:
             assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
             tokens = [
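
Note: begin_think/end_think are now whitelisted alongside tool_calls, so reasoning markers survive convert_tokens_to_string instead of being stripped with the other special tokens, and downstream reasoning parsers can find them in the text. A small sketch of the whitelist idea, with illustrative token strings rather than mistral-common's real SpecialTokens values:

TOOL_CALLS = "[TOOL_CALLS]"
BEGIN_THINK = "[THINK]"
END_THINK = "[/THINK]"

ALL_SPECIAL = {TOOL_CALLS, BEGIN_THINK, END_THINK, "<s>", "</s>"}
TO_DECODE = {TOOL_CALLS, BEGIN_THINK, END_THINK}

def convert_tokens_to_string(tokens: list[str]) -> str:
    # Keep whitelisted special tokens as literal text, drop the other
    # special tokens, and pass ordinary tokens through unchanged.
    kept = [t for t in tokens if t not in ALL_SPECIAL or t in TO_DECODE]
    return "".join(kept)

print(convert_tokens_to_string(["<s>", BEGIN_THINK, "hm", END_THINK, "hi", "</s>"]))
# -> "[THINK]hm[/THINK]hi"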

File 3 of 3:

@@ -241,7 +241,10 @@ class MistralToolParser(ToolParser):
         delta_token_ids: Sequence[int],
         request: ChatCompletionRequest,
     ) -> DeltaMessage | None:
-        if self.bot_token_id not in current_token_ids:
+        has_bot_token = (
+            self.bot_token_id in current_token_ids or self.bot_token in current_text
+        )
+        if not has_bot_token:
             # if the tool call token is not in the tokens generated so far,
             # append output to contents since it's not a tool
             return DeltaMessage(content=delta_text)
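
Note: detection now also matches the marker in the decoded text, covering paths where [TOOL_CALLS] arrives spelled out rather than as its special-token id. A minimal sketch of the widened check (BOT_TOKEN and BOT_TOKEN_ID values are illustrative):

BOT_TOKEN = "[TOOL_CALLS]"
BOT_TOKEN_ID = 5

def has_bot_token(current_token_ids: list[int], current_text: str) -> bool:
    # Match either the special-token id or the literal marker text.
    return BOT_TOKEN_ID in current_token_ids or BOT_TOKEN in current_text

assert has_bot_token([1, 5, 9], "")                      # matched by id
assert has_bot_token([1, 9], 'x [TOOL_CALLS] {"a": 1}')  # matched by text
assert not has_bot_token([1, 9], "plain content")
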
@@ -275,7 +278,8 @@ class MistralToolParser(ToolParser):
         additional_content: str = ""
         if self.streaming_state == StreamingState.WAITING_FOR_TOOL_START:
             # this is the first tool call
-            assert self.bot_token_id in delta_token_ids
+            if self.bot_token not in delta_text:
+                return DeltaMessage(content=delta_text)
             if not delta_text.startswith(self.bot_token):
                 additional_content += delta_text.split(self.bot_token)[0]
                 delta_text = self.bot_token + "".join(
@@ -411,7 +415,7 @@ class MistralToolParser(ToolParser):
             index=self.current_tool_id, type="function"
         )
         current_tool_call_modified = False
-        if self.bot_token_id in delta_token_ids:
+        if self.bot_token_id in delta_token_ids or self.bot_token in delta_text:
            # this is the first tool call
            if not delta_text.startswith(self.bot_token):
                content = delta_text.split(self.bot_token)[0]
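
Note: the same widening applies here, and any prose preceding the marker is split off as plain content. A sketch of that split, mirroring the delta_text.split(self.bot_token) logic above (BOT_TOKEN value illustrative):

BOT_TOKEN = "[TOOL_CALLS]"

def split_leading_content(delta_text: str) -> tuple[str, str]:
    # Return (plain content before the marker, marker plus remainder).
    if not delta_text.startswith(BOT_TOKEN) and BOT_TOKEN in delta_text:
        before, after = delta_text.split(BOT_TOKEN, 1)
        return before, BOT_TOKEN + after
    return "", delta_text

assert split_leading_content('Sure! [TOOL_CALLS] [{"name"') == (
    "Sure! ",
    '[TOOL_CALLS] [{"name"',
)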