diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index 2eb550c3e..ad7982b61 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -310,11 +310,14 @@ class OpenAIServingChat(OpenAIServing): trace_headers=trace_headers, ) else: - reasoning_ended = ( - reasoning_parser.is_reasoning_end(prompt_token_ids or []) - if reasoning_parser - else None - ) + if not request.include_reasoning: + reasoning_ended = True + elif reasoning_parser: + reasoning_ended = reasoning_parser.is_reasoning_end( + prompt_token_ids or [] + ) + else: + reasoning_ended = None generator = self.engine_client.generate( engine_prompt, diff --git a/vllm/tokenizers/mistral.py b/vllm/tokenizers/mistral.py index ca61edeb8..e20f1edd4 100644 --- a/vllm/tokenizers/mistral.py +++ b/vllm/tokenizers/mistral.py @@ -15,8 +15,15 @@ from mistral_common.protocol.instruct.validator import ValidationMode from mistral_common.tokens.tokenizers.base import ( SpecialTokenPolicy, SpecialTokens, + Tokenizer, +) +from mistral_common.tokens.tokenizers.instruct import ( + InstructTokenizerBase, + InstructTokenizerV13, +) +from mistral_common.tokens.tokenizers.mistral import ( + MistralTokenizer as MistralCommonTokenizer, ) -from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13 from mistral_common.tokens.tokenizers.sentencepiece import ( SentencePieceTokenizer, ) @@ -26,21 +33,20 @@ from pydantic import ValidationError from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest from vllm.logger import init_logger +from vllm.tokenizers.protocol import TokenizerLike -from .protocol import TokenizerLike +try: + # Transformers v5 + from transformers.tokenization_mistral_common import MistralCommonBackend +except ImportError: + # Transformers v4 + from transformers.tokenization_mistral_common import ( + MistralCommonTokenizer as MistralCommonBackend, + ) if TYPE_CHECKING: from transformers import BatchEncoding - try: - # Transformers v5 - from transformers.tokenization_mistral_common import MistralCommonBackend - except ImportError: - # Transformers v4 - from transformers.tokenization_mistral_common import ( - MistralCommonTokenizer as MistralCommonBackend, - ) - logger = init_logger(__name__) @@ -235,15 +241,6 @@ class MistralTokenizer(TokenizerLike): download_dir: str | None = None, **kwargs, ) -> "MistralTokenizer": - try: - # Transformers v5 - from transformers.tokenization_mistral_common import MistralCommonBackend - except ImportError: - # Transformers v4 - from transformers.tokenization_mistral_common import ( - MistralCommonTokenizer as MistralCommonBackend, - ) - tokenizer = MistralCommonBackend.from_pretrained( path_or_repo_id, *args, @@ -255,13 +252,13 @@ class MistralTokenizer(TokenizerLike): return cls(tokenizer) - def __init__(self, tokenizer: "MistralCommonBackend") -> None: + def __init__(self, tokenizer: MistralCommonBackend) -> None: super().__init__() - self.transformers_tokenizer = tokenizer - self.mistral = tokenizer.tokenizer - self.instruct = self.mistral.instruct_tokenizer - self.tokenizer = self.instruct.tokenizer + self.transformers_tokenizer: MistralCommonBackend = tokenizer + self.mistral: MistralCommonTokenizer = tokenizer.tokenizer + self.instruct: InstructTokenizerBase = self.mistral.instruct_tokenizer + self.tokenizer: Tokenizer = self.instruct.tokenizer mode = self.mistral._chat_completion_request_validator._mode if mode != ValidationMode.test: @@ -483,7 +480,11 @@ class MistralTokenizer(TokenizerLike): return self.transformers_tokenizer.convert_tokens_to_ids(tokens) def convert_tokens_to_string(self, tokens: list[str]) -> str: - to_decode_special_tokens = {SpecialTokens.tool_calls} + to_decode_special_tokens = { + SpecialTokens.tool_calls, + SpecialTokens.begin_think, + SpecialTokens.end_think, + } if self.is_tekken: assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer) tokens = [ diff --git a/vllm/tool_parsers/mistral_tool_parser.py b/vllm/tool_parsers/mistral_tool_parser.py index baab4ade0..56ba245ce 100644 --- a/vllm/tool_parsers/mistral_tool_parser.py +++ b/vllm/tool_parsers/mistral_tool_parser.py @@ -241,7 +241,10 @@ class MistralToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> DeltaMessage | None: - if self.bot_token_id not in current_token_ids: + has_bot_token = ( + self.bot_token_id in current_token_ids or self.bot_token in current_text + ) + if not has_bot_token: # if the tool call token is not in the tokens generated so far, # append output to contents since it's not a tool return DeltaMessage(content=delta_text) @@ -275,7 +278,8 @@ class MistralToolParser(ToolParser): additional_content: str = "" if self.streaming_state == StreamingState.WAITING_FOR_TOOL_START: # this is the first tool call - assert self.bot_token_id in delta_token_ids + if self.bot_token not in delta_text: + return DeltaMessage(content=delta_text) if not delta_text.startswith(self.bot_token): additional_content += delta_text.split(self.bot_token)[0] delta_text = self.bot_token + "".join( @@ -411,7 +415,7 @@ class MistralToolParser(ToolParser): index=self.current_tool_id, type="function" ) current_tool_call_modified = False - if self.bot_token_id in delta_token_ids: + if self.bot_token_id in delta_token_ids or self.bot_token in delta_text: # this is the first tool call if not delta_text.startswith(self.bot_token): content = delta_text.split(self.bot_token)[0]