[Mistral Grammar] Support Grammar Factory (#38150)

Signed-off-by: juliendenize <julien.denize@mistral.ai>
This commit is contained in:
Julien Denize
2026-04-06 16:28:51 +02:00
committed by GitHub
parent c5e3454e5a
commit fef56c1855
10 changed files with 601 additions and 29 deletions

View File

@@ -3,8 +3,10 @@
from typing import Any
import llguidance
import pytest
from mistral_common.exceptions import InvalidMessageStructureException
from mistral_common.guidance.grammar_factory import GrammarFactory
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from vllm.tokenizers.mistral import (
@@ -2407,3 +2409,29 @@ class TestMistralTokenizer:
assert actual_tokens == expected_tokens
assert mistral_tokenizer.convert_ids_to_tokens([]) == []
def test_grammar_factory(self, mistral_tokenizer: MistralTokenizer) -> None:
    """Grammar factory is only exposed for tekken tokenizers and is cached."""
    # The Mistral 7B fixture is pre-v11 SPM (not tekken), so it exercises
    # the AttributeError branch here.
    if not mistral_tokenizer.is_tekken:
        with pytest.raises(AttributeError):
            mistral_tokenizer.grammar_factory  # noqa: B018
        return
    first_access = mistral_tokenizer.grammar_factory
    assert isinstance(first_access, GrammarFactory)
    # Repeated access must return the identical cached instance.
    second_access = mistral_tokenizer.grammar_factory
    assert second_access is first_access
def test_llg_tokenizer(self, mistral_tokenizer: MistralTokenizer) -> None:
    """llguidance tokenizer is only available for tekken tokenizers and is cached."""
    # Non-tekken tokenizers raise instead of building an llguidance tokenizer.
    if not mistral_tokenizer.is_tekken:
        with pytest.raises(ValueError):
            mistral_tokenizer.llg_tokenizer  # noqa: B018
        return
    first_access = mistral_tokenizer.llg_tokenizer
    assert isinstance(first_access, llguidance.LLTokenizer)
    # The property must hand back the same cached object on every access.
    assert mistral_tokenizer.llg_tokenizer is first_access

View File

@@ -3,19 +3,43 @@
import json
from collections.abc import Generator
from unittest.mock import MagicMock, patch
import partial_json_parser
import pytest
from mistral_common.protocol.instruct.messages import AssistantMessage
from mistral_common.protocol.instruct.request import InstructRequest
from mistral_common.protocol.instruct.tool_calls import FunctionCall, ToolCall
from mistral_common.protocol.instruct.tool_calls import (
FunctionCall,
ToolCall,
)
from mistral_common.protocol.instruct.tool_calls import (
NamedToolChoice as MistralNamedToolChoice,
)
from mistral_common.protocol.instruct.tool_calls import (
ToolChoice as MistralToolChoice,
)
from mistral_common.protocol.instruct.tool_calls import (
ToolChoiceEnum as MistralToolChoiceEnum,
)
from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.engine.protocol import DeltaMessage, DeltaToolCall
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaMessage,
DeltaToolCall,
StructuralTagResponseFormat,
)
from vllm.sampling_params import StructuredOutputsParams
from vllm.tokenizers import TokenizerLike, get_tokenizer
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.tool_parsers.mistral_tool_parser import MistralToolParser
from vllm.tool_parsers.mistral_tool_parser import (
_DEFAULT_JSON_SCHEMA,
MistralToolParser,
)
@pytest.fixture(scope="module")
@@ -40,6 +64,13 @@ def mistral_tool_parser(mistral_tokenizer):
return MistralToolParser(mistral_tokenizer)
@pytest.fixture
def non_mistral_parser() -> MistralToolParser:
    """Parser backed by a mocked tokenizer that is not a MistralTokenizer."""
    tokenizer_stub = MagicMock()
    # Only the vocab lookup for the tool-call token is needed by the parser.
    tokenizer_stub.get_vocab.return_value = {"[TOOL_CALLS]": 1}
    return MistralToolParser(tokenizer_stub)
def assert_tool_calls(
actual_tool_calls: list[ToolCall] | list[DeltaToolCall],
expected_tool_calls: list[ToolCall],
@@ -951,3 +982,313 @@ def test_fast_detokenization_text_detection_pre_v11(
assert len(delta_message.tool_calls) > 0
assert delta_message.tool_calls[0].function is not None
assert delta_message.tool_calls[0].function.name == "add"
# OpenAI-style tool definitions used by _make_request to build
# ChatCompletionRequest fixtures for the grammar-factory tests below.
SAMPLE_TOOLS_DICTS = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the weather",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "add",
            "description": "Add two numbers",
            "parameters": {
                "type": "object",
                "properties": {
                    "a": {"type": "number"},
                    "b": {"type": "number"},
                },
                "required": ["a", "b"],
            },
        },
    },
]
def _make_request(**kwargs) -> ChatCompletionRequest:
    """Build a ChatCompletionRequest with test defaults, overridable via kwargs."""
    # kwargs win over the defaults, matching dict.update semantics.
    params: dict = {
        "messages": [],
        "model": "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
        "tools": SAMPLE_TOOLS_DICTS,
        "tool_choice": "auto",
        **kwargs,
    }
    return ChatCompletionRequest(**params)
@pytest.mark.parametrize(
    "request_kwargs,expected_mode,expected_parallel",
    [
        ({"tool_choice": "auto"}, MistralToolChoiceEnum.auto, True),
        ({"tool_choice": "none"}, MistralToolChoiceEnum.none, True),
        ({"tool_choice": "required"}, MistralToolChoiceEnum.required, True),
        ({"tool_choice": None, "tools": None}, MistralToolChoiceEnum.auto, True),
        (
            {
                "tool_choice": {
                    "type": "function",
                    "function": {"name": "get_weather"},
                }
            },
            MistralNamedToolChoice.model_validate(
                {"type": "function", "function": {"name": "get_weather"}}
            ),
            True,
        ),
        (
            {"tool_choice": "auto", "parallel_tool_calls": False},
            MistralToolChoiceEnum.auto,
            False,
        ),
        (
            {"tool_choice": "auto", "response_format": {"type": "text"}},
            MistralToolChoiceEnum.auto,
            True,
        ),
    ],
    ids=[
        "auto",
        "none",
        "required",
        "null_tool_choice",
        "named_tool_choice",
        "parallel_false",
        "response_format_text",
    ],
)
def test_adjust_request_grammar_factory(
    mistral_tool_parser: MistralToolParser,
    request_kwargs: dict,
    expected_mode: MistralToolChoice,
    expected_parallel: bool,
) -> None:
    """adjust_request forwards tool-choice mode and parallelism to the grammar factory."""
    request = _make_request(**request_kwargs)
    grammar_factory = mistral_tool_parser.model_tokenizer.grammar_factory
    # Spy on the real jinja grammar builder to inspect the forwarded kwargs.
    with patch.object(
        grammar_factory,
        "get_lark_from_jinja",
        wraps=grammar_factory.get_lark_from_jinja,
    ) as jinja_spy:
        adjusted = mistral_tool_parser.adjust_request(request)
    jinja_spy.assert_called_once()
    forwarded = jinja_spy.call_args.kwargs
    assert forwarded["mode"] == expected_mode
    # No JSON schema is involved in the pure tool-choice paths.
    assert forwarded["json_schema"] is None
    assert forwarded["parallel_tool_calls"] == expected_parallel
    outputs = adjusted.structured_outputs
    assert outputs is not None
    assert isinstance(outputs.grammar, str)
    assert len(outputs.grammar) > 0
def test_adjust_request_unsupported_grammar_for_tokenizer(mistral_tokenizer) -> None:
    """No structured outputs are attached when the tokenizer lacks grammar support."""
    # Patch the class-level property so supports_grammar evaluates to False.
    with patch.object(
        type(mistral_tokenizer),
        "supports_grammar",
        new_callable=lambda: property(lambda self: False),
    ):
        parser = MistralToolParser(mistral_tokenizer)
        adjusted = parser.adjust_request(_make_request())
        assert adjusted.structured_outputs is None
@pytest.mark.parametrize(
    "tool_choice,expected_skip",
    [("auto", False), ("none", True)],
    ids=["auto_skip_false", "none_skip_true"],
)
def test_adjust_request_non_mistral_tokenizer(
    non_mistral_parser: MistralToolParser,
    tool_choice: str,
    expected_skip: bool,
) -> None:
    """With a non-Mistral tokenizer, only skip_special_tokens is adjusted."""
    request = _make_request(tool_choice=tool_choice)
    adjusted = non_mistral_parser.adjust_request(request)
    # tool_choice="none" disables tool-call tokens, so special tokens are skipped.
    assert adjusted.skip_special_tokens is expected_skip
@pytest.mark.parametrize(
    "so_kwargs",
    [
        {"regex": r"\d+"},
        {"choice": ["a", "b"]},
        {"structural_tag": '{"key": "value"}'},
        {"grammar": "start: 'hello'"},
    ],
    ids=["regex", "choice", "structural_tag", "grammar"],
)
def test_adjust_request_unsupported_structured_outputs(
    mistral_tool_parser: MistralToolParser,
    so_kwargs: dict,
) -> None:
    """User-supplied structured-output params the parser cannot handle pass through."""
    request = _make_request(
        structured_outputs=StructuredOutputsParams(**so_kwargs),
    )
    adjusted = mistral_tool_parser.adjust_request(request)
    # The parser must not overwrite the user's structured-output settings.
    assert adjusted.structured_outputs == request.structured_outputs
def test_adjust_request_unsupported_response_format(
    mistral_tool_parser: MistralToolParser,
) -> None:
    """Structural-tag response formats are preserved and produce no grammar."""
    structural_tag = StructuralTagResponseFormat(
        type="structural_tag", format={"some": "config"}
    )
    request = _make_request(response_format=structural_tag)
    adjusted = mistral_tool_parser.adjust_request(request)
    assert adjusted.structured_outputs is None
    # The original response_format must survive adjustment untouched.
    assert adjusted.response_format == request.response_format
@pytest.mark.parametrize(
    "so_kwargs,expected_json_schema",
    [
        ({"json_object": True}, _DEFAULT_JSON_SCHEMA),
        ({"json": '{"type": "object"}'}, {"type": "object"}),
        (
            {"json": {"type": "object", "properties": {"x": {"type": "integer"}}}},
            {"type": "object", "properties": {"x": {"type": "integer"}}},
        ),
    ],
    ids=["json_object", "json_str", "json_dict"],
)
def test_adjust_request_structured_outputs_generates_grammar(
    mistral_tool_parser: MistralToolParser,
    so_kwargs: dict,
    # Fixed annotation: the parametrized values are dicts, not str.
    # NOTE(review): assumes _DEFAULT_JSON_SCHEMA is also a dict — confirm.
    expected_json_schema: dict,
) -> None:
    """JSON structured-output params are compiled into a lark grammar via the jinja factory."""
    request = _make_request(
        structured_outputs=StructuredOutputsParams(**so_kwargs),
    )
    factory = mistral_tool_parser.model_tokenizer.grammar_factory
    with patch.object(
        factory,
        "get_lark_from_jinja",
        wraps=factory.get_lark_from_jinja,
    ) as mock_get_lark:
        result = mistral_tool_parser.adjust_request(request)
    mock_get_lark.assert_called_once()
    # The schema (parsed, if supplied as a JSON string) reaches the factory.
    assert mock_get_lark.call_args.kwargs["json_schema"] == expected_json_schema
    assert result.structured_outputs is not None
    assert isinstance(result.structured_outputs.grammar, str)
    assert len(result.structured_outputs.grammar) > 0
@pytest.mark.parametrize(
    "response_format_kwargs,expected_json_schema",
    [
        ({"type": "json_object"}, _DEFAULT_JSON_SCHEMA),
        (
            {
                "type": "json_schema",
                "json_schema": {
                    "name": "my_schema",
                    "schema": {
                        "type": "object",
                        "properties": {"x": {"type": "integer"}},
                    },
                },
            },
            {"type": "object", "properties": {"x": {"type": "integer"}}},
        ),
    ],
    ids=["json_object", "json_schema_with_schema"],
)
def test_adjust_request_response_format_generates_grammar(
    mistral_tool_parser: MistralToolParser,
    response_format_kwargs: dict,
    # Fixed annotation: the parametrized values are dicts, not str.
    # NOTE(review): assumes _DEFAULT_JSON_SCHEMA is also a dict — confirm.
    expected_json_schema: dict,
) -> None:
    """OpenAI response_format JSON modes are compiled into a lark grammar."""
    request = _make_request(response_format=response_format_kwargs)
    factory = mistral_tool_parser.model_tokenizer.grammar_factory
    with patch.object(
        factory,
        "get_lark_from_jinja",
        wraps=factory.get_lark_from_jinja,
    ) as mock_get_lark:
        result = mistral_tool_parser.adjust_request(request)
    mock_get_lark.assert_called_once()
    # The inner "schema" payload (or the default) is what the factory receives.
    assert mock_get_lark.call_args.kwargs["json_schema"] == expected_json_schema
    assert result.structured_outputs is not None
    assert isinstance(result.structured_outputs.grammar, str)
    assert len(result.structured_outputs.grammar) > 0
def test_adjust_request_tool_choice_none_with_json_schema_uses_json_schema_factory(
    mistral_tool_parser: MistralToolParser,
) -> None:
    """With tool_choice='none', a JSON schema is compiled via get_lark_for_json_schema."""
    request = _make_request(
        tool_choice="none",
        structured_outputs=StructuredOutputsParams(json='{"type": "object"}'),
    )
    grammar_factory = mistral_tool_parser.model_tokenizer.grammar_factory
    # Spy on the schema-only factory entry point; tools are disabled here.
    with patch.object(
        grammar_factory,
        "get_lark_for_json_schema",
        wraps=grammar_factory.get_lark_for_json_schema,
    ) as schema_spy:
        adjusted = mistral_tool_parser.adjust_request(request)
    schema_spy.assert_called_once()
    assert schema_spy.call_args.kwargs["json_schema"] == {"type": "object"}
    outputs = adjusted.structured_outputs
    assert outputs is not None
    assert isinstance(outputs.grammar, str)
    assert len(outputs.grammar) > 0
def test_adjust_request_tool_choice_auto_with_json_schema_uses_jinja_factory(
    mistral_tool_parser: MistralToolParser,
) -> None:
    """With tool_choice='auto', a JSON schema is routed through the jinja factory."""
    request = _make_request(
        tool_choice="auto",
        structured_outputs=StructuredOutputsParams(json='{"type": "object"}'),
    )
    grammar_factory = mistral_tool_parser.model_tokenizer.grammar_factory
    # Spy on both factory entry points to verify which one handles the schema.
    with (
        patch.object(
            grammar_factory,
            "get_lark_for_json_schema",
            wraps=grammar_factory.get_lark_for_json_schema,
        ) as schema_spy,
        patch.object(
            grammar_factory,
            "get_lark_from_jinja",
            wraps=grammar_factory.get_lark_from_jinja,
        ) as jinja_spy,
    ):
        adjusted = mistral_tool_parser.adjust_request(request)
    jinja_spy.assert_called_once()
    assert jinja_spy.call_args.kwargs["json_schema"] == {"type": "object"}
    # The schema-only path must NOT be taken when tools remain enabled.
    schema_spy.assert_not_called()
    outputs = adjusted.structured_outputs
    assert outputs is not None
    assert isinstance(outputs.grammar, str)
    assert len(outputs.grammar) > 0

View File

@@ -11,6 +11,7 @@ from vllm.config.model import ModelConfig
from vllm.config.parallel import ParallelConfig
from vllm.config.speculative import SpeculativeConfig
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.tokenizers import get_tokenizer
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
@@ -19,6 +20,14 @@ from vllm.v1.structured_output.backend_types import StructuredOutputOptions
TOKENIZER = "gpt2"
@pytest.fixture(scope="module")
def mistral_tokenizer():
    """Module-scoped Mistral tokenizer used by the grammar-compilation tests."""
    tokenizer = get_tokenizer(
        tokenizer_name="mistralai/Mistral-Small-3.2-24B-Instruct-2506",
        tokenizer_mode="mistral",
    )
    return tokenizer
def test_backend_guidance_rollback_terminated():
# Test that the backend guidance successfully rollbacks from a
# terminated state. This can happen with speculative decoding,
@@ -187,3 +196,38 @@ def test_grammar_init_async_and_sync(async_grammar):
# Verify the grammar can accept valid tokens
assert grammar.accept_tokens(request.request_id, prompt)
@pytest.mark.parametrize(
    "request_type,grammar_spec",
    [
        pytest.param(
            StructuredOutputOptions.JSON,
            '{"type": "object"}',
            id="json",
        ),
        pytest.param(
            StructuredOutputOptions.GRAMMAR,
            'start: "hello" | "world"',
            id="lark",
        ),
    ],
)
def test_mistral_tokenizer_compile_grammar(
    mistral_tokenizer,
    request_type: StructuredOutputOptions,
    grammar_spec: str,
) -> None:
    """GuidanceBackend reuses the Mistral tokenizer's llguidance tokenizer and compiles grammars."""
    config = VllmConfig(
        structured_outputs_config=StructuredOutputsConfig(backend="guidance"),
    )
    guidance_backend = GuidanceBackend(
        config,
        tokenizer=mistral_tokenizer,
        vocab_size=mistral_tokenizer.vocab_size,
    )
    # The backend must share the tokenizer's cached llguidance tokenizer
    # rather than constructing a fresh one.
    assert guidance_backend.ll_tokenizer is mistral_tokenizer.llg_tokenizer
    compiled = guidance_backend.compile_grammar(request_type, grammar_spec)
    assert compiled is not None
    assert not compiled.is_terminated()