From fef56c18555e881c671acf654630732b7271c14f Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Mon, 6 Apr 2026 16:28:51 +0200 Subject: [PATCH] [Mistral Grammar] Support Grammar Factory (#38150) Signed-off-by: juliendenize --- requirements/common.txt | 2 +- requirements/rocm-test.txt | 2 +- requirements/test.txt | 2 +- tests/tokenizers_/test_mistral.py | 28 ++ .../tool_parsers/test_mistral_tool_parser.py | 347 +++++++++++++++++- .../test_backend_guidance.py | 44 +++ vllm/sampling_params.py | 28 +- vllm/tokenizers/mistral.py | 25 ++ vllm/tool_parsers/mistral_tool_parser.py | 142 ++++++- vllm/v1/structured_output/backend_guidance.py | 10 +- 10 files changed, 601 insertions(+), 29 deletions(-) diff --git a/requirements/common.txt b/requirements/common.txt index 05666c5d1..b610fd678 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -31,7 +31,7 @@ partial-json-parser # used for parsing partial JSON outputs pyzmq >= 25.0.0 msgspec gguf >= 0.17.0 -mistral_common[image] >= 1.10.0 +mistral_common[image] >= 1.11.0 opencv-python-headless >= 4.13.0 # required for video IO pyyaml six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 diff --git a/requirements/rocm-test.txt b/requirements/rocm-test.txt index d5afde3c8..a441bfef0 100644 --- a/requirements/rocm-test.txt +++ b/requirements/rocm-test.txt @@ -604,7 +604,7 @@ mcp==1.27.0 # via -r requirements/common.txt mdurl==0.1.2 # via markdown-it-py -mistral-common==1.10.0 +mistral-common==1.11.0 # via # -c requirements/common.txt # -r requirements/common.txt diff --git a/requirements/test.txt b/requirements/test.txt index 642e589a6..c8ff5fcab 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -508,7 +508,7 @@ mbstrdecoder==1.1.3 # typepy mdurl==0.1.2 # via markdown-it-py -mistral-common==1.10.0 +mistral-common==1.11.0 # via # -c requirements/common.txt # -r requirements/test.in diff --git a/tests/tokenizers_/test_mistral.py b/tests/tokenizers_/test_mistral.py index faff61150..2b101e8f9 100644 --- a/tests/tokenizers_/test_mistral.py +++ b/tests/tokenizers_/test_mistral.py @@ -3,8 +3,10 @@ from typing import Any +import llguidance import pytest from mistral_common.exceptions import InvalidMessageStructureException +from mistral_common.guidance.grammar_factory import GrammarFactory from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy from vllm.tokenizers.mistral import ( @@ -2407,3 +2409,29 @@ class TestMistralTokenizer: assert actual_tokens == expected_tokens assert mistral_tokenizer.convert_ids_to_tokens([]) == [] + + def test_grammar_factory(self, mistral_tokenizer: MistralTokenizer) -> None: + # works in this case cause Mistral 7B is < v11 and SPM + if not mistral_tokenizer.is_tekken: + with pytest.raises(AttributeError): + mistral_tokenizer.grammar_factory # noqa: B018 + return + factory = mistral_tokenizer.grammar_factory + assert isinstance(factory, GrammarFactory) + + # Test caching + factory_2 = mistral_tokenizer.grammar_factory + assert factory is factory_2 + + def test_llg_tokenizer(self, mistral_tokenizer: MistralTokenizer) -> None: + if not mistral_tokenizer.is_tekken: + with pytest.raises(ValueError): + mistral_tokenizer.llg_tokenizer # noqa: B018 + return + + llg_tokenizer = mistral_tokenizer.llg_tokenizer + assert isinstance(llg_tokenizer, llguidance.LLTokenizer) + + # Test caching + llg_tokenizer_2 = mistral_tokenizer.llg_tokenizer + assert llg_tokenizer is llg_tokenizer_2 diff --git a/tests/tool_parsers/test_mistral_tool_parser.py b/tests/tool_parsers/test_mistral_tool_parser.py index 4be564666..064ccb39e 100644 --- a/tests/tool_parsers/test_mistral_tool_parser.py +++ b/tests/tool_parsers/test_mistral_tool_parser.py @@ -3,19 +3,43 @@ import json from collections.abc import Generator +from unittest.mock import MagicMock, patch import partial_json_parser import pytest from mistral_common.protocol.instruct.messages import AssistantMessage from mistral_common.protocol.instruct.request import InstructRequest -from mistral_common.protocol.instruct.tool_calls import FunctionCall, ToolCall +from mistral_common.protocol.instruct.tool_calls import ( + FunctionCall, + ToolCall, +) +from mistral_common.protocol.instruct.tool_calls import ( + NamedToolChoice as MistralNamedToolChoice, +) +from mistral_common.protocol.instruct.tool_calls import ( + ToolChoice as MistralToolChoice, +) +from mistral_common.protocol.instruct.tool_calls import ( + ToolChoiceEnum as MistralToolChoiceEnum, +) from partial_json_parser.core.options import Allow -from vllm.entrypoints.openai.engine.protocol import DeltaMessage, DeltaToolCall +from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, +) +from vllm.entrypoints.openai.engine.protocol import ( + DeltaMessage, + DeltaToolCall, + StructuralTagResponseFormat, +) +from vllm.sampling_params import StructuredOutputsParams from vllm.tokenizers import TokenizerLike, get_tokenizer from vllm.tokenizers.detokenizer_utils import detokenize_incrementally from vllm.tokenizers.mistral import MistralTokenizer -from vllm.tool_parsers.mistral_tool_parser import MistralToolParser +from vllm.tool_parsers.mistral_tool_parser import ( + _DEFAULT_JSON_SCHEMA, + MistralToolParser, +) @pytest.fixture(scope="module") @@ -40,6 +64,13 @@ def mistral_tool_parser(mistral_tokenizer): return MistralToolParser(mistral_tokenizer) +@pytest.fixture +def non_mistral_parser() -> MistralToolParser: + mock_tokenizer = MagicMock() + mock_tokenizer.get_vocab.return_value = {"[TOOL_CALLS]": 1} + return MistralToolParser(mock_tokenizer) + + def assert_tool_calls( actual_tool_calls: list[ToolCall] | list[DeltaToolCall], expected_tool_calls: list[ToolCall], @@ -951,3 +982,313 @@ def test_fast_detokenization_text_detection_pre_v11( assert len(delta_message.tool_calls) > 0 assert delta_message.tool_calls[0].function is not None assert delta_message.tool_calls[0].function.name == "add" + + +SAMPLE_TOOLS_DICTS = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the weather", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "add", + "description": "Add two numbers", + "parameters": { + "type": "object", + "properties": { + "a": {"type": "number"}, + "b": {"type": "number"}, + }, + "required": ["a", "b"], + }, + }, + }, +] + + +def _make_request(**kwargs) -> ChatCompletionRequest: + defaults: dict = { + "messages": [], + "model": "mistralai/Mistral-Small-3.2-24B-Instruct-2506", + "tools": SAMPLE_TOOLS_DICTS, + "tool_choice": "auto", + } + defaults.update(kwargs) + return ChatCompletionRequest(**defaults) + + +@pytest.mark.parametrize( + "request_kwargs,expected_mode,expected_parallel", + [ + ({"tool_choice": "auto"}, MistralToolChoiceEnum.auto, True), + ({"tool_choice": "none"}, MistralToolChoiceEnum.none, True), + ({"tool_choice": "required"}, MistralToolChoiceEnum.required, True), + ({"tool_choice": None, "tools": None}, MistralToolChoiceEnum.auto, True), + ( + { + "tool_choice": { + "type": "function", + "function": {"name": "get_weather"}, + } + }, + MistralNamedToolChoice.model_validate( + {"type": "function", "function": {"name": "get_weather"}} + ), + True, + ), + ( + {"tool_choice": "auto", "parallel_tool_calls": False}, + MistralToolChoiceEnum.auto, + False, + ), + ( + {"tool_choice": "auto", "response_format": {"type": "text"}}, + MistralToolChoiceEnum.auto, + True, + ), + ], + ids=[ + "auto", + "none", + "required", + "null_tool_choice", + "named_tool_choice", + "parallel_false", + "response_format_text", + ], +) +def test_adjust_request_grammar_factory( + mistral_tool_parser: MistralToolParser, + request_kwargs: dict, + expected_mode: MistralToolChoice, + expected_parallel: bool, +) -> None: + request = _make_request(**request_kwargs) + factory = mistral_tool_parser.model_tokenizer.grammar_factory + + with patch.object( + factory, + "get_lark_from_jinja", + wraps=factory.get_lark_from_jinja, + ) as mock_get_lark: + result = mistral_tool_parser.adjust_request(request) + + mock_get_lark.assert_called_once() + call_kwargs = mock_get_lark.call_args + + assert call_kwargs.kwargs["mode"] == expected_mode + assert call_kwargs.kwargs["json_schema"] is None + assert call_kwargs.kwargs["parallel_tool_calls"] == expected_parallel + + assert result.structured_outputs is not None + assert isinstance(result.structured_outputs.grammar, str) + assert len(result.structured_outputs.grammar) > 0 + + +def test_adjust_request_unsupported_grammar_for_tokenizer(mistral_tokenizer) -> None: + with patch.object( + type(mistral_tokenizer), + "supports_grammar", + new_callable=lambda: property(lambda self: False), + ): + parser = MistralToolParser(mistral_tokenizer) + request = _make_request() + result = parser.adjust_request(request) + + assert result.structured_outputs is None + + +@pytest.mark.parametrize( + "tool_choice,expected_skip", + [("auto", False), ("none", True)], + ids=["auto_skip_false", "none_skip_true"], +) +def test_adjust_request_non_mistral_tokenizer( + non_mistral_parser: MistralToolParser, + tool_choice: str, + expected_skip: bool, +) -> None: + request = _make_request(tool_choice=tool_choice) + result = non_mistral_parser.adjust_request(request) + + assert result.skip_special_tokens is expected_skip + + +@pytest.mark.parametrize( + "so_kwargs", + [ + {"regex": r"\d+"}, + {"choice": ["a", "b"]}, + {"structural_tag": '{"key": "value"}'}, + {"grammar": "start: 'hello'"}, + ], + ids=["regex", "choice", "structural_tag", "grammar"], +) +def test_adjust_request_unsupported_structured_outputs( + mistral_tool_parser: MistralToolParser, + so_kwargs: dict, +) -> None: + request = _make_request( + structured_outputs=StructuredOutputsParams(**so_kwargs), + ) + result = mistral_tool_parser.adjust_request(request) + + assert result.structured_outputs == request.structured_outputs + + +def test_adjust_request_unsupported_response_format( + mistral_tool_parser: MistralToolParser, +) -> None: + request = _make_request( + response_format=StructuralTagResponseFormat( + type="structural_tag", format={"some": "config"} + ), + ) + result = mistral_tool_parser.adjust_request(request) + assert result.structured_outputs is None + assert result.response_format == request.response_format + + +@pytest.mark.parametrize( + "so_kwargs,expected_json_schema", + [ + ({"json_object": True}, _DEFAULT_JSON_SCHEMA), + ({"json": '{"type": "object"}'}, {"type": "object"}), + ( + {"json": {"type": "object", "properties": {"x": {"type": "integer"}}}}, + {"type": "object", "properties": {"x": {"type": "integer"}}}, + ), + ], + ids=["json_object", "json_str", "json_dict"], +) +def test_adjust_request_structured_outputs_generates_grammar( + mistral_tool_parser: MistralToolParser, + so_kwargs: dict, + expected_json_schema: str, +) -> None: + request = _make_request( + structured_outputs=StructuredOutputsParams(**so_kwargs), + ) + factory = mistral_tool_parser.model_tokenizer.grammar_factory + + with patch.object( + factory, + "get_lark_from_jinja", + wraps=factory.get_lark_from_jinja, + ) as mock_get_lark: + result = mistral_tool_parser.adjust_request(request) + + mock_get_lark.assert_called_once() + assert mock_get_lark.call_args.kwargs["json_schema"] == expected_json_schema + + assert result.structured_outputs is not None + assert isinstance(result.structured_outputs.grammar, str) + assert len(result.structured_outputs.grammar) > 0 + + +@pytest.mark.parametrize( + "response_format_kwargs,expected_json_schema", + [ + ({"type": "json_object"}, _DEFAULT_JSON_SCHEMA), + ( + { + "type": "json_schema", + "json_schema": { + "name": "my_schema", + "schema": { + "type": "object", + "properties": {"x": {"type": "integer"}}, + }, + }, + }, + {"type": "object", "properties": {"x": {"type": "integer"}}}, + ), + ], + ids=["json_object", "json_schema_with_schema"], +) +def test_adjust_request_response_format_generates_grammar( + mistral_tool_parser: MistralToolParser, + response_format_kwargs: dict, + expected_json_schema: str, +) -> None: + request = _make_request(response_format=response_format_kwargs) + factory = mistral_tool_parser.model_tokenizer.grammar_factory + + with patch.object( + factory, + "get_lark_from_jinja", + wraps=factory.get_lark_from_jinja, + ) as mock_get_lark: + result = mistral_tool_parser.adjust_request(request) + + mock_get_lark.assert_called_once() + assert mock_get_lark.call_args.kwargs["json_schema"] == expected_json_schema + + assert result.structured_outputs is not None + assert isinstance(result.structured_outputs.grammar, str) + assert len(result.structured_outputs.grammar) > 0 + + +def test_adjust_request_tool_choice_none_with_json_schema_uses_json_schema_factory( + mistral_tool_parser: MistralToolParser, +) -> None: + request = _make_request( + tool_choice="none", + structured_outputs=StructuredOutputsParams(json='{"type": "object"}'), + ) + factory = mistral_tool_parser.model_tokenizer.grammar_factory + + with patch.object( + factory, + "get_lark_for_json_schema", + wraps=factory.get_lark_for_json_schema, + ) as mock_json_schema: + result = mistral_tool_parser.adjust_request(request) + + mock_json_schema.assert_called_once() + assert mock_json_schema.call_args.kwargs["json_schema"] == {"type": "object"} + + assert result.structured_outputs is not None + assert isinstance(result.structured_outputs.grammar, str) + assert len(result.structured_outputs.grammar) > 0 + + +def test_adjust_request_tool_choice_auto_with_json_schema_uses_jinja_factory( + mistral_tool_parser: MistralToolParser, +) -> None: + request = _make_request( + tool_choice="auto", + structured_outputs=StructuredOutputsParams(json='{"type": "object"}'), + ) + factory = mistral_tool_parser.model_tokenizer.grammar_factory + + with ( + patch.object( + factory, + "get_lark_for_json_schema", + wraps=factory.get_lark_for_json_schema, + ) as mock_json_schema, + patch.object( + factory, + "get_lark_from_jinja", + wraps=factory.get_lark_from_jinja, + ) as mock_jinja, + ): + result = mistral_tool_parser.adjust_request(request) + + mock_jinja.assert_called_once() + assert mock_jinja.call_args.kwargs["json_schema"] == {"type": "object"} + mock_json_schema.assert_not_called() + + assert result.structured_outputs is not None + assert isinstance(result.structured_outputs.grammar, str) + assert len(result.structured_outputs.grammar) > 0 diff --git a/tests/v1/structured_output/test_backend_guidance.py b/tests/v1/structured_output/test_backend_guidance.py index 704ed8b9c..ca8c9b0d7 100644 --- a/tests/v1/structured_output/test_backend_guidance.py +++ b/tests/v1/structured_output/test_backend_guidance.py @@ -11,6 +11,7 @@ from vllm.config.model import ModelConfig from vllm.config.parallel import ParallelConfig from vllm.config.speculative import SpeculativeConfig from vllm.sampling_params import SamplingParams, StructuredOutputsParams +from vllm.tokenizers import get_tokenizer from vllm.v1.request import Request from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.structured_output.backend_guidance import GuidanceBackend @@ -19,6 +20,14 @@ from vllm.v1.structured_output.backend_types import StructuredOutputOptions TOKENIZER = "gpt2" +@pytest.fixture(scope="module") +def mistral_tokenizer(): + return get_tokenizer( + tokenizer_name="mistralai/Mistral-Small-3.2-24B-Instruct-2506", + tokenizer_mode="mistral", + ) + + def test_backend_guidance_rollback_terminated(): # Test that the backend guidance successfully rollbacks from a # terminated state. This can happen with speculative decoding, @@ -187,3 +196,38 @@ def test_grammar_init_async_and_sync(async_grammar): # Verify the grammar can accept valid tokens assert grammar.accept_tokens(request.request_id, prompt) + + +@pytest.mark.parametrize( + "request_type,grammar_spec", + [ + pytest.param( + StructuredOutputOptions.JSON, + '{"type": "object"}', + id="json", + ), + pytest.param( + StructuredOutputOptions.GRAMMAR, + 'start: "hello" | "world"', + id="lark", + ), + ], +) +def test_mistral_tokenizer_compile_grammar( + mistral_tokenizer, + request_type: StructuredOutputOptions, + grammar_spec: str, +) -> None: + vllm_config = VllmConfig( + structured_outputs_config=StructuredOutputsConfig(backend="guidance"), + ) + backend = GuidanceBackend( + vllm_config, + tokenizer=mistral_tokenizer, + vocab_size=mistral_tokenizer.vocab_size, + ) + assert backend.ll_tokenizer is mistral_tokenizer.llg_tokenizer + + grammar = backend.compile_grammar(request_type, grammar_spec) + assert grammar is not None + assert not grammar.is_terminated() diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 97976b832..9bcc66959 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -153,6 +153,10 @@ class RequestOutputKind(Enum): FINAL_ONLY = 2 +def _is_non_tekken_mistral(tokenizer: TokenizerLike) -> bool: + return is_mistral_tokenizer(tokenizer) and not tokenizer.is_tekken + + class SamplingParams( PydanticMsgspecMixin, msgspec.Struct, @@ -801,16 +805,17 @@ class SamplingParams( # xgrammar with no fallback validate_xgrammar_grammar(self) elif backend.startswith("guidance"): + if _is_non_tekken_mistral(tokenizer=tokenizer): + raise ValueError( + "Non-tekken Mistral tokenizers are not supported for the 'guidance'" + " structured output backend. Please either use a more recent " + "Mistral model, the ['xgrammar', 'outlines'] " + "backends or tokenizer_mode='hf' instead." + ) # TODO: ideally we would have the LLTokenizer here as Lark syntax # allows <|special_token|> and similar, see # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens # Without tokenizer these are disallowed in grammars. - if is_mistral_tokenizer(tokenizer): - raise ValueError( - "Mistral tokenizer is not supported for the 'guidance' " - "structured output backend. Please use ['xgrammar', 'outlines'] " - "backends or tokenizer_mode='hf' instead." - ) validate_guidance_grammar(self, tokenizer=None) elif backend == "outlines": # outlines backend @@ -839,19 +844,20 @@ class SamplingParams( # or includes some jsonschema feature(s) that # are not supported in xgrammar. + skip_guidance = _is_non_tekken_mistral(tokenizer) + # Check if schema has features unsupported by guidance so_params = self.structured_outputs - skip_guidance = False - if so_params.json: + if not skip_guidance and so_params.json: if isinstance(so_params.json, str): schema = json_mod.loads(so_params.json) else: schema = so_params.json skip_guidance = has_guidance_unsupported_json_features(schema) - if is_mistral_tokenizer(tokenizer) or skip_guidance: - # Fall back to outlines if the tokenizer is Mistral - # or if schema contains features unsupported by guidance + if skip_guidance: + # Fall back to outlines if the tokenizer is non-tekken Mistral or + # the schema contains features unsupported by guidance validate_structured_output_request_outlines(self) self.structured_outputs._backend = "outlines" else: diff --git a/vllm/tokenizers/mistral.py b/vllm/tokenizers/mistral.py index e20f1edd4..147dca888 100644 --- a/vllm/tokenizers/mistral.py +++ b/vllm/tokenizers/mistral.py @@ -1,9 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Sequence +from functools import cached_property from pathlib import Path from typing import TYPE_CHECKING, Any, cast, overload +from mistral_common.guidance.grammar_factory import GrammarFactory +from mistral_common.guidance.tokenizer import from_mistral_tokenizer from mistral_common.protocol.instruct.request import ( ChatCompletionRequest as MistralChatCompletionRequest, ) @@ -45,6 +48,7 @@ except ImportError: ) if TYPE_CHECKING: + import llguidance from transformers import BatchEncoding logger = init_logger(__name__) @@ -574,3 +578,24 @@ class MistralTokenizer(TokenizerLike): ] return tokens + + @property + def supports_grammar(self) -> bool: + return GrammarFactory.is_supported(self.mistral) + + @cached_property + def grammar_factory(self) -> GrammarFactory: + if not self.supports_grammar: + raise AttributeError( + "This tokenizer does not support `grammar_factory`. " + "This is only supported for tekken tokenizers with " + "version >= 11." + ) + # Cache grammar factory to avoid creating a llguidance tokenizer at every usage. + return GrammarFactory(self.mistral) + + @cached_property + def llg_tokenizer(self) -> "llguidance.LLTokenizer": + if not self.is_tekken: + raise ValueError("`llg_tokenizer` is only supported for Tekkenizers.") + return from_mistral_tokenizer(self.mistral) diff --git a/vllm/tool_parsers/mistral_tool_parser.py b/vllm/tool_parsers/mistral_tool_parser.py index dc92522a0..4d1aaffed 100644 --- a/vllm/tool_parsers/mistral_tool_parser.py +++ b/vllm/tool_parsers/mistral_tool_parser.py @@ -10,6 +10,18 @@ from typing import Any import ijson import regex as re +from mistral_common.protocol.instruct.tool_calls import ( + NamedToolChoice as MistralNamedToolChoice, +) +from mistral_common.protocol.instruct.tool_calls import ( + Tool as MistralTool, +) +from mistral_common.protocol.instruct.tool_calls import ( + ToolChoice as MistralToolChoice, +) +from mistral_common.protocol.instruct.tool_calls import ( + ToolChoiceEnum as MistralToolChoiceEnum, +) from pydantic import Field from vllm.entrypoints.openai.chat_completion.protocol import ( @@ -25,6 +37,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ) from vllm.entrypoints.openai.responses.protocol import ResponsesRequest from vllm.logger import init_logger +from vllm.sampling_params import StructuredOutputsParams from vllm.tokenizers import TokenizerLike from vllm.tool_parsers.abstract_tool_parser import ( Tool, @@ -36,6 +49,8 @@ logger = init_logger(__name__) ALPHANUMERIC = ascii_letters + digits +_DEFAULT_JSON_SCHEMA = {"anyOf": [{"type": "object"}, {"type": "array"}]} + class StreamingState(Enum): """Enum for tracking the current streaming parsing state.""" @@ -80,6 +95,9 @@ class MistralToolParser(ToolParser): Used when --enable-auto-tool-choice --tool-call-parser mistral are all set """ + # Used to generate correct grammar in `adjust_request` + model_can_reason: bool = False + def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None): super().__init__(tokenizer, tools) @@ -115,18 +133,124 @@ class MistralToolParser(ToolParser): def adjust_request( self, request: ChatCompletionRequest | ResponsesRequest ) -> ChatCompletionRequest | ResponsesRequest: - request = super().adjust_request(request) + so_non_supported_attributes = [ + "regex", + "choice", + "grammar", + # whitespace_pattern is not a constraint type but an option; + # Mistral grammar factory does not support it. + "whitespace_pattern", + "structural_tag", + ] + any_so_non_supported_active = request.structured_outputs is not None and any( + getattr(request.structured_outputs, attribute) is not None + for attribute in so_non_supported_attributes + ) + response_format_non_supported_active = ( + isinstance(request, ResponsesRequest) + or request.response_format is not None + and request.response_format.type == "structural_tag" + ) + if ( not is_mistral_tokenizer(self.model_tokenizer) - and request.tools - and request.tool_choice != "none" + or isinstance(request, ResponsesRequest) + or not self.model_tokenizer.supports_grammar + or any_so_non_supported_active + or response_format_non_supported_active ): - # Do not skip special tokens when using chat template - # with Mistral parser as TOOL_CALL token is needed - # for tool detection. - # Note: we don't want skip_special_tokens=False - # with MistralTokenizer as it is incompatible - request.skip_special_tokens = False + request = super().adjust_request(request) + if request.tools and request.tool_choice != "none": + # Do not skip special tokens when using chat template + # with Mistral parser as TOOL_CALL token is needed + # for tool detection. + # Note: we don't want skip_special_tokens=False + # with MistralTokenizer as it is incompatible + request.skip_special_tokens = False + return request + + json_schema: dict[str, Any] | None = None + if request.structured_outputs is not None: + if request.structured_outputs.json_object is not None: + json_schema = _DEFAULT_JSON_SCHEMA + elif request.structured_outputs.json is not None: + if isinstance(request.structured_outputs.json, str): + json_schema = json.loads(request.structured_outputs.json) + else: + json_schema = request.structured_outputs.json + else: + raise ValueError( + "Unsupported request.structured_outputs for MistralToolParser. " + "Only `json` and `json_object` are supported." + ) + elif ( + request.response_format is not None + and request.response_format.type != "text" + ): + if request.response_format.type == "json_object": + json_schema = _DEFAULT_JSON_SCHEMA + elif request.response_format.type == "json_schema": + if request.response_format.json_schema is not None: + json_schema = request.response_format.json_schema.json_schema + else: + json_schema = _DEFAULT_JSON_SCHEMA + else: + raise ValueError( + "MistralToolParser only accepts `text`, `json_object` or " + f"`json_schema`, got {request.response_format=}" + ) + # Structured Outputs will be defined. + request.response_format = None + + grammar_factory = self.model_tokenizer.grammar_factory + + # TODO: Once unified parser, improve this. + # The issue is figuring out when a model is a reasoning one or not. + template = grammar_factory.select_jinja_template( + reasoning=self.model_can_reason + ) + + tools = ( + [ + MistralTool.from_openai(openai_tool=tool.model_dump()) + for tool in request.tools + ] + if request.tools is not None + else None + ) + + tool_choice: MistralToolChoice + match request.tool_choice: + case "none" | "auto" | "required": + tool_choice = MistralToolChoiceEnum(request.tool_choice) + case None: + tool_choice = MistralToolChoiceEnum.auto + # _ == Named tool choice + case _: + tool_choice = MistralNamedToolChoice.model_validate( + { + "type": "function", + "function": {"name": request.tool_choice.function.name}, + } + ) + + # Rendering grammar is cached in mistral-common given tools, template and mode. + match tool_choice, json_schema is not None: + case MistralToolChoiceEnum.none, True: + lark_grammar = grammar_factory.get_lark_for_json_schema( + template=template, json_schema=json_schema + ) + case _, _: + lark_grammar = grammar_factory.get_lark_from_jinja( + template=template, + mode=tool_choice, + tools=tools, + json_schema=json_schema, + parallel_tool_calls=request.parallel_tool_calls, + json_only=False, + ) + + request.structured_outputs = StructuredOutputsParams(grammar=lark_grammar) return request def extract_tool_calls( diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py index 6063a2dc2..31178e9f2 100644 --- a/vllm/v1/structured_output/backend_guidance.py +++ b/vllm/v1/structured_output/backend_guidance.py @@ -12,6 +12,7 @@ import torch from vllm.logger import init_logger from vllm.sampling_params import SamplingParams from vllm.utils.import_utils import LazyLoader +from vllm.utils.mistral import is_mistral_tokenizer from vllm.v1.structured_output.backend_types import ( StructuredOutputBackend, StructuredOutputGrammar, @@ -92,9 +93,12 @@ class GuidanceBackend(StructuredOutputBackend): self.vllm_config.structured_outputs_config.disable_additional_properties ) - self.ll_tokenizer = llguidance_hf.from_tokenizer( - self.tokenizer, max(self.vocab_size, len(self.tokenizer)) - ) + if is_mistral_tokenizer(self.tokenizer): + self.ll_tokenizer = self.tokenizer.llg_tokenizer + else: + self.ll_tokenizer = llguidance_hf.from_tokenizer( + self.tokenizer, max(self.vocab_size, len(self.tokenizer)) + ) def compile_grammar( self, request_type: StructuredOutputOptions, grammar_spec: str