[Mistral Grammar] Support Grammar Factory (#38150)
Signed-off-by: juliendenize <julien.denize@mistral.ai>
This commit is contained in:
@@ -3,8 +3,10 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
import llguidance
|
||||
import pytest
|
||||
from mistral_common.exceptions import InvalidMessageStructureException
|
||||
from mistral_common.guidance.grammar_factory import GrammarFactory
|
||||
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
|
||||
|
||||
from vllm.tokenizers.mistral import (
|
||||
@@ -2407,3 +2409,29 @@ class TestMistralTokenizer:
|
||||
assert actual_tokens == expected_tokens
|
||||
|
||||
assert mistral_tokenizer.convert_ids_to_tokens([]) == []
|
||||
|
||||
def test_grammar_factory(self, mistral_tokenizer: MistralTokenizer) -> None:
    """The grammar_factory property exists only for tekken tokenizers and is cached."""
    if not mistral_tokenizer.is_tekken:
        # Non-tekken path: works in this case because Mistral 7B is < v11 and SPM,
        # so accessing the property must raise.
        with pytest.raises(AttributeError):
            mistral_tokenizer.grammar_factory  # noqa: B018
        return

    first = mistral_tokenizer.grammar_factory
    assert isinstance(first, GrammarFactory)

    # Re-accessing the property must hand back the same cached instance.
    assert mistral_tokenizer.grammar_factory is first
|
||||
|
||||
def test_llg_tokenizer(self, mistral_tokenizer: MistralTokenizer) -> None:
    """The llg_tokenizer property exists only for tekken tokenizers and is cached."""
    if not mistral_tokenizer.is_tekken:
        # Non-tekken tokenizers reject the property outright.
        with pytest.raises(ValueError):
            mistral_tokenizer.llg_tokenizer  # noqa: B018
        return

    first = mistral_tokenizer.llg_tokenizer
    assert isinstance(first, llguidance.LLTokenizer)

    # Re-accessing the property must return the cached object.
    assert mistral_tokenizer.llg_tokenizer is first
|
||||
|
||||
@@ -3,19 +3,43 @@
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import partial_json_parser
|
||||
import pytest
|
||||
from mistral_common.protocol.instruct.messages import AssistantMessage
|
||||
from mistral_common.protocol.instruct.request import InstructRequest
|
||||
from mistral_common.protocol.instruct.tool_calls import FunctionCall, ToolCall
|
||||
from mistral_common.protocol.instruct.tool_calls import (
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from mistral_common.protocol.instruct.tool_calls import (
|
||||
NamedToolChoice as MistralNamedToolChoice,
|
||||
)
|
||||
from mistral_common.protocol.instruct.tool_calls import (
|
||||
ToolChoice as MistralToolChoice,
|
||||
)
|
||||
from mistral_common.protocol.instruct.tool_calls import (
|
||||
ToolChoiceEnum as MistralToolChoiceEnum,
|
||||
)
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.openai.engine.protocol import DeltaMessage, DeltaToolCall
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionRequest,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
StructuralTagResponseFormat,
|
||||
)
|
||||
from vllm.sampling_params import StructuredOutputsParams
|
||||
from vllm.tokenizers import TokenizerLike, get_tokenizer
|
||||
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.tool_parsers.mistral_tool_parser import MistralToolParser
|
||||
from vllm.tool_parsers.mistral_tool_parser import (
|
||||
_DEFAULT_JSON_SCHEMA,
|
||||
MistralToolParser,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -40,6 +64,13 @@ def mistral_tool_parser(mistral_tokenizer):
|
||||
return MistralToolParser(mistral_tokenizer)
|
||||
|
||||
|
||||
@pytest.fixture
def non_mistral_parser() -> MistralToolParser:
    """Parser backed by a mocked, non-Mistral tokenizer.

    The mock only needs a vocab containing the [TOOL_CALLS] marker so the
    parser can resolve its special-token id.
    """
    fake_tokenizer = MagicMock()
    fake_tokenizer.get_vocab.return_value = {"[TOOL_CALLS]": 1}
    parser = MistralToolParser(fake_tokenizer)
    return parser
|
||||
|
||||
|
||||
def assert_tool_calls(
|
||||
actual_tool_calls: list[ToolCall] | list[DeltaToolCall],
|
||||
expected_tool_calls: list[ToolCall],
|
||||
@@ -951,3 +982,313 @@ def test_fast_detokenization_text_detection_pre_v11(
|
||||
assert len(delta_message.tool_calls) > 0
|
||||
assert delta_message.tool_calls[0].function is not None
|
||||
assert delta_message.tool_calls[0].function.name == "add"
|
||||
|
||||
|
||||
# Two representative OpenAI-style tool definitions used to build chat requests:
# a single-argument tool ("get_weather") and a two-argument tool ("add").
SAMPLE_TOOLS_DICTS = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {"type": "string"},
                },
                "required": ["city"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "add",
            "description": "Add two numbers",
            "parameters": {
                "type": "object",
                "properties": {
                    "a": {"type": "number"},
                    "b": {"type": "number"},
                },
                "required": ["a", "b"],
            },
        },
    },
]
|
||||
|
||||
|
||||
def _make_request(**kwargs) -> ChatCompletionRequest:
    """Build a ChatCompletionRequest with test defaults; kwargs override them."""
    params: dict = {
        "messages": [],
        "model": "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
        "tools": SAMPLE_TOOLS_DICTS,
        "tool_choice": "auto",
        # Caller-supplied keys win over the defaults above.
        **kwargs,
    }
    return ChatCompletionRequest(**params)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "request_kwargs,expected_mode,expected_parallel",
    [
        ({"tool_choice": "auto"}, MistralToolChoiceEnum.auto, True),
        ({"tool_choice": "none"}, MistralToolChoiceEnum.none, True),
        ({"tool_choice": "required"}, MistralToolChoiceEnum.required, True),
        # A null tool_choice (with no tools) falls back to "auto".
        ({"tool_choice": None, "tools": None}, MistralToolChoiceEnum.auto, True),
        (
            {
                "tool_choice": {
                    "type": "function",
                    "function": {"name": "get_weather"},
                }
            },
            MistralNamedToolChoice.model_validate(
                {"type": "function", "function": {"name": "get_weather"}}
            ),
            True,
        ),
        (
            {"tool_choice": "auto", "parallel_tool_calls": False},
            MistralToolChoiceEnum.auto,
            False,
        ),
        # A plain-text response_format must not disturb the tool-choice mode.
        (
            {"tool_choice": "auto", "response_format": {"type": "text"}},
            MistralToolChoiceEnum.auto,
            True,
        ),
    ],
    ids=[
        "auto",
        "none",
        "required",
        "null_tool_choice",
        "named_tool_choice",
        "parallel_false",
        "response_format_text",
    ],
)
def test_adjust_request_grammar_factory(
    mistral_tool_parser: MistralToolParser,
    request_kwargs: dict,
    expected_mode: MistralToolChoice,
    expected_parallel: bool,
) -> None:
    """adjust_request must forward tool-choice mode and parallelism to the
    grammar factory and attach the produced lark grammar to the request."""
    request = _make_request(**request_kwargs)
    factory = mistral_tool_parser.model_tokenizer.grammar_factory

    # Spy on the real factory method (wraps=) so the genuine grammar is still
    # produced while we can inspect the call.
    with patch.object(
        factory,
        "get_lark_from_jinja",
        wraps=factory.get_lark_from_jinja,
    ) as mock_get_lark:
        result = mistral_tool_parser.adjust_request(request)

    mock_get_lark.assert_called_once()
    call_kwargs = mock_get_lark.call_args

    assert call_kwargs.kwargs["mode"] == expected_mode
    # No JSON schema in these cases — only tool-choice handling is exercised.
    assert call_kwargs.kwargs["json_schema"] is None
    assert call_kwargs.kwargs["parallel_tool_calls"] == expected_parallel

    # A non-empty lark grammar string must be attached to the adjusted request.
    assert result.structured_outputs is not None
    assert isinstance(result.structured_outputs.grammar, str)
    assert len(result.structured_outputs.grammar) > 0
|
||||
|
||||
|
||||
def test_adjust_request_unsupported_grammar_for_tokenizer(mistral_tokenizer) -> None:
    """When the tokenizer reports no grammar support, adjust_request attaches nothing."""
    # Replace the read-only supports_grammar property on the class with one
    # that always reports False.
    with patch.object(
        type(mistral_tokenizer),
        "supports_grammar",
        new_callable=lambda: property(lambda self: False),
    ):
        parser = MistralToolParser(mistral_tokenizer)
        adjusted = parser.adjust_request(_make_request())

    assert adjusted.structured_outputs is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "tool_choice,expected_skip",
    [("auto", False), ("none", True)],
    ids=["auto_skip_false", "none_skip_true"],
)
def test_adjust_request_non_mistral_tokenizer(
    non_mistral_parser: MistralToolParser,
    tool_choice: str,
    expected_skip: bool,
) -> None:
    """Without a Mistral tokenizer, only skip_special_tokens is adjusted:
    kept on for tool_choice="none", forced off otherwise."""
    request = _make_request(tool_choice=tool_choice)

    adjusted = non_mistral_parser.adjust_request(request)

    assert adjusted.skip_special_tokens is expected_skip
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "so_kwargs",
    [
        {"regex": r"\d+"},
        {"choice": ["a", "b"]},
        {"structural_tag": '{"key": "value"}'},
        {"grammar": "start: 'hello'"},
    ],
    ids=["regex", "choice", "structural_tag", "grammar"],
)
def test_adjust_request_unsupported_structured_outputs(
    mistral_tool_parser: MistralToolParser,
    so_kwargs: dict,
) -> None:
    """Structured-output kinds the grammar factory does not handle (regex,
    choice, structural_tag, raw grammar) must be passed through unchanged."""
    request = _make_request(
        structured_outputs=StructuredOutputsParams(**so_kwargs),
    )

    adjusted = mistral_tool_parser.adjust_request(request)

    # The original structured_outputs object survives untouched.
    assert adjusted.structured_outputs == request.structured_outputs
|
||||
|
||||
|
||||
def test_adjust_request_unsupported_response_format(
    mistral_tool_parser: MistralToolParser,
) -> None:
    """A structural_tag response_format is not grammar-backed: no grammar is
    generated and the response_format is left as-is."""
    structural_tag_format = StructuralTagResponseFormat(
        type="structural_tag", format={"some": "config"}
    )
    request = _make_request(response_format=structural_tag_format)

    adjusted = mistral_tool_parser.adjust_request(request)

    assert adjusted.structured_outputs is None
    assert adjusted.response_format == request.response_format
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "so_kwargs,expected_json_schema",
    [
        # json_object=True means "any JSON object" -> the default schema.
        ({"json_object": True}, _DEFAULT_JSON_SCHEMA),
        # A JSON-string schema is decoded before reaching the factory.
        ({"json": '{"type": "object"}'}, {"type": "object"}),
        # A dict schema is forwarded as-is.
        (
            {"json": {"type": "object", "properties": {"x": {"type": "integer"}}}},
            {"type": "object", "properties": {"x": {"type": "integer"}}},
        ),
    ],
    ids=["json_object", "json_str", "json_dict"],
)
def test_adjust_request_structured_outputs_generates_grammar(
    mistral_tool_parser: MistralToolParser,
    so_kwargs: dict,
    # NOTE(review): annotation fixed from `str` — the parametrized values are
    # dicts (and _DEFAULT_JSON_SCHEMA, presumably a dict — confirm in parser).
    expected_json_schema: dict,
) -> None:
    """JSON structured-output requests must be turned into a lark grammar via
    the jinja grammar factory, with the decoded schema forwarded to it."""
    request = _make_request(
        structured_outputs=StructuredOutputsParams(**so_kwargs),
    )
    factory = mistral_tool_parser.model_tokenizer.grammar_factory

    # Spy (wraps=) keeps the real behavior while recording the call.
    with patch.object(
        factory,
        "get_lark_from_jinja",
        wraps=factory.get_lark_from_jinja,
    ) as mock_get_lark:
        result = mistral_tool_parser.adjust_request(request)

    mock_get_lark.assert_called_once()
    assert mock_get_lark.call_args.kwargs["json_schema"] == expected_json_schema

    # A non-empty lark grammar string is attached to the adjusted request.
    assert result.structured_outputs is not None
    assert isinstance(result.structured_outputs.grammar, str)
    assert len(result.structured_outputs.grammar) > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "response_format_kwargs,expected_json_schema",
    [
        # "json_object" means "any JSON object" -> the default schema.
        ({"type": "json_object"}, _DEFAULT_JSON_SCHEMA),
        # "json_schema" forwards the inner schema dict to the factory.
        (
            {
                "type": "json_schema",
                "json_schema": {
                    "name": "my_schema",
                    "schema": {
                        "type": "object",
                        "properties": {"x": {"type": "integer"}},
                    },
                },
            },
            {"type": "object", "properties": {"x": {"type": "integer"}}},
        ),
    ],
    ids=["json_object", "json_schema_with_schema"],
)
def test_adjust_request_response_format_generates_grammar(
    mistral_tool_parser: MistralToolParser,
    response_format_kwargs: dict,
    # NOTE(review): annotation fixed from `str` — the parametrized values are
    # dicts (and _DEFAULT_JSON_SCHEMA, presumably a dict — confirm in parser).
    expected_json_schema: dict,
) -> None:
    """OpenAI response_format JSON modes must be converted into a lark grammar
    via the jinja grammar factory, forwarding the extracted schema."""
    request = _make_request(response_format=response_format_kwargs)
    factory = mistral_tool_parser.model_tokenizer.grammar_factory

    # Spy (wraps=) keeps the real behavior while recording the call.
    with patch.object(
        factory,
        "get_lark_from_jinja",
        wraps=factory.get_lark_from_jinja,
    ) as mock_get_lark:
        result = mistral_tool_parser.adjust_request(request)

    mock_get_lark.assert_called_once()
    assert mock_get_lark.call_args.kwargs["json_schema"] == expected_json_schema

    # A non-empty lark grammar string is attached to the adjusted request.
    assert result.structured_outputs is not None
    assert isinstance(result.structured_outputs.grammar, str)
    assert len(result.structured_outputs.grammar) > 0
|
||||
|
||||
|
||||
def test_adjust_request_tool_choice_none_with_json_schema_uses_json_schema_factory(
    mistral_tool_parser: MistralToolParser,
) -> None:
    """With tool_choice="none" a JSON schema bypasses the jinja path and goes
    through get_lark_for_json_schema directly."""
    request = _make_request(
        tool_choice="none",
        structured_outputs=StructuredOutputsParams(json='{"type": "object"}'),
    )
    grammar_factory = mistral_tool_parser.model_tokenizer.grammar_factory

    # Spy on the real method so the grammar is still generated.
    with patch.object(
        grammar_factory,
        "get_lark_for_json_schema",
        wraps=grammar_factory.get_lark_for_json_schema,
    ) as json_schema_spy:
        adjusted = mistral_tool_parser.adjust_request(request)

    json_schema_spy.assert_called_once()
    assert json_schema_spy.call_args.kwargs["json_schema"] == {"type": "object"}

    # A non-empty lark grammar string must be attached.
    assert adjusted.structured_outputs is not None
    assert isinstance(adjusted.structured_outputs.grammar, str)
    assert len(adjusted.structured_outputs.grammar) > 0
|
||||
|
||||
|
||||
def test_adjust_request_tool_choice_auto_with_json_schema_uses_jinja_factory(
    mistral_tool_parser: MistralToolParser,
) -> None:
    """With tool_choice="auto" a JSON schema must go through the jinja factory
    (tools and schema combined), never the plain json-schema factory."""
    request = _make_request(
        tool_choice="auto",
        structured_outputs=StructuredOutputsParams(json='{"type": "object"}'),
    )
    grammar_factory = mistral_tool_parser.model_tokenizer.grammar_factory

    # Spy on both candidate factory methods to see which one is chosen.
    with (
        patch.object(
            grammar_factory,
            "get_lark_for_json_schema",
            wraps=grammar_factory.get_lark_for_json_schema,
        ) as json_schema_spy,
        patch.object(
            grammar_factory,
            "get_lark_from_jinja",
            wraps=grammar_factory.get_lark_from_jinja,
        ) as jinja_spy,
    ):
        adjusted = mistral_tool_parser.adjust_request(request)

    jinja_spy.assert_called_once()
    assert jinja_spy.call_args.kwargs["json_schema"] == {"type": "object"}
    json_schema_spy.assert_not_called()

    # A non-empty lark grammar string must be attached.
    assert adjusted.structured_outputs is not None
    assert isinstance(adjusted.structured_outputs.grammar, str)
    assert len(adjusted.structured_outputs.grammar) > 0
|
||||
|
||||
@@ -11,6 +11,7 @@ from vllm.config.model import ModelConfig
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
from vllm.config.speculative import SpeculativeConfig
|
||||
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
from vllm.v1.request import Request
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
|
||||
@@ -19,6 +20,14 @@ from vllm.v1.structured_output.backend_types import StructuredOutputOptions
|
||||
TOKENIZER = "gpt2"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def mistral_tokenizer():
    """Module-scoped Mistral tokenizer, loaded once and shared across tests."""
    tokenizer = get_tokenizer(
        tokenizer_name="mistralai/Mistral-Small-3.2-24B-Instruct-2506",
        tokenizer_mode="mistral",
    )
    return tokenizer
|
||||
|
||||
|
||||
def test_backend_guidance_rollback_terminated():
|
||||
# Test that the backend guidance successfully rollbacks from a
|
||||
# terminated state. This can happen with speculative decoding,
|
||||
@@ -187,3 +196,38 @@ def test_grammar_init_async_and_sync(async_grammar):
|
||||
|
||||
# Verify the grammar can accept valid tokens
|
||||
assert grammar.accept_tokens(request.request_id, prompt)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "request_type,grammar_spec",
    [
        # JSON-schema spec compiled by the guidance backend.
        pytest.param(
            StructuredOutputOptions.JSON,
            '{"type": "object"}',
            id="json",
        ),
        # Lark grammar spec compiled by the guidance backend.
        pytest.param(
            StructuredOutputOptions.GRAMMAR,
            'start: "hello" | "world"',
            id="lark",
        ),
    ],
)
def test_mistral_tokenizer_compile_grammar(
    mistral_tokenizer,
    request_type: StructuredOutputOptions,
    grammar_spec: str,
) -> None:
    """The guidance backend must reuse the Mistral tokenizer's llguidance
    tokenizer and compile both JSON and lark grammar specs."""
    vllm_config = VllmConfig(
        structured_outputs_config=StructuredOutputsConfig(backend="guidance"),
    )
    backend = GuidanceBackend(
        vllm_config,
        tokenizer=mistral_tokenizer,
        vocab_size=mistral_tokenizer.vocab_size,
    )
    # The backend must pick up the tokenizer's cached llg_tokenizer,
    # not build its own.
    assert backend.ll_tokenizer is mistral_tokenizer.llg_tokenizer

    # Compilation succeeds and yields a live (non-terminated) grammar.
    grammar = backend.compile_grammar(request_type, grammar_spec)
    assert grammar is not None
    assert not grammar.is_terminated()
|
||||
|
||||
Reference in New Issue
Block a user