[Mistral Grammar] Support Grammar Factory (#38150)
Signed-off-by: juliendenize <julien.denize@mistral.ai>
This commit is contained in:
@@ -31,7 +31,7 @@ partial-json-parser # used for parsing partial JSON outputs
|
||||
pyzmq >= 25.0.0
|
||||
msgspec
|
||||
gguf >= 0.17.0
|
||||
mistral_common[image] >= 1.10.0
|
||||
mistral_common[image] >= 1.11.0
|
||||
opencv-python-headless >= 4.13.0 # required for video IO
|
||||
pyyaml
|
||||
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
|
||||
|
||||
@@ -604,7 +604,7 @@ mcp==1.27.0
|
||||
# via -r requirements/common.txt
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mistral-common==1.10.0
|
||||
mistral-common==1.11.0
|
||||
# via
|
||||
# -c requirements/common.txt
|
||||
# -r requirements/common.txt
|
||||
|
||||
@@ -508,7 +508,7 @@ mbstrdecoder==1.1.3
|
||||
# typepy
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mistral-common==1.10.0
|
||||
mistral-common==1.11.0
|
||||
# via
|
||||
# -c requirements/common.txt
|
||||
# -r requirements/test.in
|
||||
|
||||
@@ -3,8 +3,10 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
import llguidance
|
||||
import pytest
|
||||
from mistral_common.exceptions import InvalidMessageStructureException
|
||||
from mistral_common.guidance.grammar_factory import GrammarFactory
|
||||
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
|
||||
|
||||
from vllm.tokenizers.mistral import (
|
||||
@@ -2407,3 +2409,29 @@ class TestMistralTokenizer:
|
||||
assert actual_tokens == expected_tokens
|
||||
|
||||
assert mistral_tokenizer.convert_ids_to_tokens([]) == []
|
||||
|
||||
def test_grammar_factory(self, mistral_tokenizer: MistralTokenizer) -> None:
|
||||
# works in this case cause Mistral 7B is < v11 and SPM
|
||||
if not mistral_tokenizer.is_tekken:
|
||||
with pytest.raises(AttributeError):
|
||||
mistral_tokenizer.grammar_factory # noqa: B018
|
||||
return
|
||||
factory = mistral_tokenizer.grammar_factory
|
||||
assert isinstance(factory, GrammarFactory)
|
||||
|
||||
# Test caching
|
||||
factory_2 = mistral_tokenizer.grammar_factory
|
||||
assert factory is factory_2
|
||||
|
||||
def test_llg_tokenizer(self, mistral_tokenizer: MistralTokenizer) -> None:
|
||||
if not mistral_tokenizer.is_tekken:
|
||||
with pytest.raises(ValueError):
|
||||
mistral_tokenizer.llg_tokenizer # noqa: B018
|
||||
return
|
||||
|
||||
llg_tokenizer = mistral_tokenizer.llg_tokenizer
|
||||
assert isinstance(llg_tokenizer, llguidance.LLTokenizer)
|
||||
|
||||
# Test caching
|
||||
llg_tokenizer_2 = mistral_tokenizer.llg_tokenizer
|
||||
assert llg_tokenizer is llg_tokenizer_2
|
||||
|
||||
@@ -3,19 +3,43 @@
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import partial_json_parser
|
||||
import pytest
|
||||
from mistral_common.protocol.instruct.messages import AssistantMessage
|
||||
from mistral_common.protocol.instruct.request import InstructRequest
|
||||
from mistral_common.protocol.instruct.tool_calls import FunctionCall, ToolCall
|
||||
from mistral_common.protocol.instruct.tool_calls import (
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from mistral_common.protocol.instruct.tool_calls import (
|
||||
NamedToolChoice as MistralNamedToolChoice,
|
||||
)
|
||||
from mistral_common.protocol.instruct.tool_calls import (
|
||||
ToolChoice as MistralToolChoice,
|
||||
)
|
||||
from mistral_common.protocol.instruct.tool_calls import (
|
||||
ToolChoiceEnum as MistralToolChoiceEnum,
|
||||
)
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.openai.engine.protocol import DeltaMessage, DeltaToolCall
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionRequest,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
StructuralTagResponseFormat,
|
||||
)
|
||||
from vllm.sampling_params import StructuredOutputsParams
|
||||
from vllm.tokenizers import TokenizerLike, get_tokenizer
|
||||
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.tool_parsers.mistral_tool_parser import MistralToolParser
|
||||
from vllm.tool_parsers.mistral_tool_parser import (
|
||||
_DEFAULT_JSON_SCHEMA,
|
||||
MistralToolParser,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -40,6 +64,13 @@ def mistral_tool_parser(mistral_tokenizer):
|
||||
return MistralToolParser(mistral_tokenizer)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def non_mistral_parser() -> MistralToolParser:
|
||||
mock_tokenizer = MagicMock()
|
||||
mock_tokenizer.get_vocab.return_value = {"[TOOL_CALLS]": 1}
|
||||
return MistralToolParser(mock_tokenizer)
|
||||
|
||||
|
||||
def assert_tool_calls(
|
||||
actual_tool_calls: list[ToolCall] | list[DeltaToolCall],
|
||||
expected_tool_calls: list[ToolCall],
|
||||
@@ -951,3 +982,313 @@ def test_fast_detokenization_text_detection_pre_v11(
|
||||
assert len(delta_message.tool_calls) > 0
|
||||
assert delta_message.tool_calls[0].function is not None
|
||||
assert delta_message.tool_calls[0].function.name == "add"
|
||||
|
||||
|
||||
SAMPLE_TOOLS_DICTS = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "add",
|
||||
"description": "Add two numbers",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "number"},
|
||||
"b": {"type": "number"},
|
||||
},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def _make_request(**kwargs) -> ChatCompletionRequest:
|
||||
defaults: dict = {
|
||||
"messages": [],
|
||||
"model": "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
|
||||
"tools": SAMPLE_TOOLS_DICTS,
|
||||
"tool_choice": "auto",
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return ChatCompletionRequest(**defaults)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"request_kwargs,expected_mode,expected_parallel",
|
||||
[
|
||||
({"tool_choice": "auto"}, MistralToolChoiceEnum.auto, True),
|
||||
({"tool_choice": "none"}, MistralToolChoiceEnum.none, True),
|
||||
({"tool_choice": "required"}, MistralToolChoiceEnum.required, True),
|
||||
({"tool_choice": None, "tools": None}, MistralToolChoiceEnum.auto, True),
|
||||
(
|
||||
{
|
||||
"tool_choice": {
|
||||
"type": "function",
|
||||
"function": {"name": "get_weather"},
|
||||
}
|
||||
},
|
||||
MistralNamedToolChoice.model_validate(
|
||||
{"type": "function", "function": {"name": "get_weather"}}
|
||||
),
|
||||
True,
|
||||
),
|
||||
(
|
||||
{"tool_choice": "auto", "parallel_tool_calls": False},
|
||||
MistralToolChoiceEnum.auto,
|
||||
False,
|
||||
),
|
||||
(
|
||||
{"tool_choice": "auto", "response_format": {"type": "text"}},
|
||||
MistralToolChoiceEnum.auto,
|
||||
True,
|
||||
),
|
||||
],
|
||||
ids=[
|
||||
"auto",
|
||||
"none",
|
||||
"required",
|
||||
"null_tool_choice",
|
||||
"named_tool_choice",
|
||||
"parallel_false",
|
||||
"response_format_text",
|
||||
],
|
||||
)
|
||||
def test_adjust_request_grammar_factory(
|
||||
mistral_tool_parser: MistralToolParser,
|
||||
request_kwargs: dict,
|
||||
expected_mode: MistralToolChoice,
|
||||
expected_parallel: bool,
|
||||
) -> None:
|
||||
request = _make_request(**request_kwargs)
|
||||
factory = mistral_tool_parser.model_tokenizer.grammar_factory
|
||||
|
||||
with patch.object(
|
||||
factory,
|
||||
"get_lark_from_jinja",
|
||||
wraps=factory.get_lark_from_jinja,
|
||||
) as mock_get_lark:
|
||||
result = mistral_tool_parser.adjust_request(request)
|
||||
|
||||
mock_get_lark.assert_called_once()
|
||||
call_kwargs = mock_get_lark.call_args
|
||||
|
||||
assert call_kwargs.kwargs["mode"] == expected_mode
|
||||
assert call_kwargs.kwargs["json_schema"] is None
|
||||
assert call_kwargs.kwargs["parallel_tool_calls"] == expected_parallel
|
||||
|
||||
assert result.structured_outputs is not None
|
||||
assert isinstance(result.structured_outputs.grammar, str)
|
||||
assert len(result.structured_outputs.grammar) > 0
|
||||
|
||||
|
||||
def test_adjust_request_unsupported_grammar_for_tokenizer(mistral_tokenizer) -> None:
|
||||
with patch.object(
|
||||
type(mistral_tokenizer),
|
||||
"supports_grammar",
|
||||
new_callable=lambda: property(lambda self: False),
|
||||
):
|
||||
parser = MistralToolParser(mistral_tokenizer)
|
||||
request = _make_request()
|
||||
result = parser.adjust_request(request)
|
||||
|
||||
assert result.structured_outputs is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tool_choice,expected_skip",
|
||||
[("auto", False), ("none", True)],
|
||||
ids=["auto_skip_false", "none_skip_true"],
|
||||
)
|
||||
def test_adjust_request_non_mistral_tokenizer(
|
||||
non_mistral_parser: MistralToolParser,
|
||||
tool_choice: str,
|
||||
expected_skip: bool,
|
||||
) -> None:
|
||||
request = _make_request(tool_choice=tool_choice)
|
||||
result = non_mistral_parser.adjust_request(request)
|
||||
|
||||
assert result.skip_special_tokens is expected_skip
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"so_kwargs",
|
||||
[
|
||||
{"regex": r"\d+"},
|
||||
{"choice": ["a", "b"]},
|
||||
{"structural_tag": '{"key": "value"}'},
|
||||
{"grammar": "start: 'hello'"},
|
||||
],
|
||||
ids=["regex", "choice", "structural_tag", "grammar"],
|
||||
)
|
||||
def test_adjust_request_unsupported_structured_outputs(
|
||||
mistral_tool_parser: MistralToolParser,
|
||||
so_kwargs: dict,
|
||||
) -> None:
|
||||
request = _make_request(
|
||||
structured_outputs=StructuredOutputsParams(**so_kwargs),
|
||||
)
|
||||
result = mistral_tool_parser.adjust_request(request)
|
||||
|
||||
assert result.structured_outputs == request.structured_outputs
|
||||
|
||||
|
||||
def test_adjust_request_unsupported_response_format(
|
||||
mistral_tool_parser: MistralToolParser,
|
||||
) -> None:
|
||||
request = _make_request(
|
||||
response_format=StructuralTagResponseFormat(
|
||||
type="structural_tag", format={"some": "config"}
|
||||
),
|
||||
)
|
||||
result = mistral_tool_parser.adjust_request(request)
|
||||
assert result.structured_outputs is None
|
||||
assert result.response_format == request.response_format
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"so_kwargs,expected_json_schema",
|
||||
[
|
||||
({"json_object": True}, _DEFAULT_JSON_SCHEMA),
|
||||
({"json": '{"type": "object"}'}, {"type": "object"}),
|
||||
(
|
||||
{"json": {"type": "object", "properties": {"x": {"type": "integer"}}}},
|
||||
{"type": "object", "properties": {"x": {"type": "integer"}}},
|
||||
),
|
||||
],
|
||||
ids=["json_object", "json_str", "json_dict"],
|
||||
)
|
||||
def test_adjust_request_structured_outputs_generates_grammar(
|
||||
mistral_tool_parser: MistralToolParser,
|
||||
so_kwargs: dict,
|
||||
expected_json_schema: str,
|
||||
) -> None:
|
||||
request = _make_request(
|
||||
structured_outputs=StructuredOutputsParams(**so_kwargs),
|
||||
)
|
||||
factory = mistral_tool_parser.model_tokenizer.grammar_factory
|
||||
|
||||
with patch.object(
|
||||
factory,
|
||||
"get_lark_from_jinja",
|
||||
wraps=factory.get_lark_from_jinja,
|
||||
) as mock_get_lark:
|
||||
result = mistral_tool_parser.adjust_request(request)
|
||||
|
||||
mock_get_lark.assert_called_once()
|
||||
assert mock_get_lark.call_args.kwargs["json_schema"] == expected_json_schema
|
||||
|
||||
assert result.structured_outputs is not None
|
||||
assert isinstance(result.structured_outputs.grammar, str)
|
||||
assert len(result.structured_outputs.grammar) > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"response_format_kwargs,expected_json_schema",
|
||||
[
|
||||
({"type": "json_object"}, _DEFAULT_JSON_SCHEMA),
|
||||
(
|
||||
{
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "my_schema",
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": "integer"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{"type": "object", "properties": {"x": {"type": "integer"}}},
|
||||
),
|
||||
],
|
||||
ids=["json_object", "json_schema_with_schema"],
|
||||
)
|
||||
def test_adjust_request_response_format_generates_grammar(
|
||||
mistral_tool_parser: MistralToolParser,
|
||||
response_format_kwargs: dict,
|
||||
expected_json_schema: str,
|
||||
) -> None:
|
||||
request = _make_request(response_format=response_format_kwargs)
|
||||
factory = mistral_tool_parser.model_tokenizer.grammar_factory
|
||||
|
||||
with patch.object(
|
||||
factory,
|
||||
"get_lark_from_jinja",
|
||||
wraps=factory.get_lark_from_jinja,
|
||||
) as mock_get_lark:
|
||||
result = mistral_tool_parser.adjust_request(request)
|
||||
|
||||
mock_get_lark.assert_called_once()
|
||||
assert mock_get_lark.call_args.kwargs["json_schema"] == expected_json_schema
|
||||
|
||||
assert result.structured_outputs is not None
|
||||
assert isinstance(result.structured_outputs.grammar, str)
|
||||
assert len(result.structured_outputs.grammar) > 0
|
||||
|
||||
|
||||
def test_adjust_request_tool_choice_none_with_json_schema_uses_json_schema_factory(
|
||||
mistral_tool_parser: MistralToolParser,
|
||||
) -> None:
|
||||
request = _make_request(
|
||||
tool_choice="none",
|
||||
structured_outputs=StructuredOutputsParams(json='{"type": "object"}'),
|
||||
)
|
||||
factory = mistral_tool_parser.model_tokenizer.grammar_factory
|
||||
|
||||
with patch.object(
|
||||
factory,
|
||||
"get_lark_for_json_schema",
|
||||
wraps=factory.get_lark_for_json_schema,
|
||||
) as mock_json_schema:
|
||||
result = mistral_tool_parser.adjust_request(request)
|
||||
|
||||
mock_json_schema.assert_called_once()
|
||||
assert mock_json_schema.call_args.kwargs["json_schema"] == {"type": "object"}
|
||||
|
||||
assert result.structured_outputs is not None
|
||||
assert isinstance(result.structured_outputs.grammar, str)
|
||||
assert len(result.structured_outputs.grammar) > 0
|
||||
|
||||
|
||||
def test_adjust_request_tool_choice_auto_with_json_schema_uses_jinja_factory(
|
||||
mistral_tool_parser: MistralToolParser,
|
||||
) -> None:
|
||||
request = _make_request(
|
||||
tool_choice="auto",
|
||||
structured_outputs=StructuredOutputsParams(json='{"type": "object"}'),
|
||||
)
|
||||
factory = mistral_tool_parser.model_tokenizer.grammar_factory
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
factory,
|
||||
"get_lark_for_json_schema",
|
||||
wraps=factory.get_lark_for_json_schema,
|
||||
) as mock_json_schema,
|
||||
patch.object(
|
||||
factory,
|
||||
"get_lark_from_jinja",
|
||||
wraps=factory.get_lark_from_jinja,
|
||||
) as mock_jinja,
|
||||
):
|
||||
result = mistral_tool_parser.adjust_request(request)
|
||||
|
||||
mock_jinja.assert_called_once()
|
||||
assert mock_jinja.call_args.kwargs["json_schema"] == {"type": "object"}
|
||||
mock_json_schema.assert_not_called()
|
||||
|
||||
assert result.structured_outputs is not None
|
||||
assert isinstance(result.structured_outputs.grammar, str)
|
||||
assert len(result.structured_outputs.grammar) > 0
|
||||
|
||||
@@ -11,6 +11,7 @@ from vllm.config.model import ModelConfig
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
from vllm.config.speculative import SpeculativeConfig
|
||||
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
from vllm.v1.request import Request
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
|
||||
@@ -19,6 +20,14 @@ from vllm.v1.structured_output.backend_types import StructuredOutputOptions
|
||||
TOKENIZER = "gpt2"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mistral_tokenizer():
|
||||
return get_tokenizer(
|
||||
tokenizer_name="mistralai/Mistral-Small-3.2-24B-Instruct-2506",
|
||||
tokenizer_mode="mistral",
|
||||
)
|
||||
|
||||
|
||||
def test_backend_guidance_rollback_terminated():
|
||||
# Test that the backend guidance successfully rollbacks from a
|
||||
# terminated state. This can happen with speculative decoding,
|
||||
@@ -187,3 +196,38 @@ def test_grammar_init_async_and_sync(async_grammar):
|
||||
|
||||
# Verify the grammar can accept valid tokens
|
||||
assert grammar.accept_tokens(request.request_id, prompt)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"request_type,grammar_spec",
|
||||
[
|
||||
pytest.param(
|
||||
StructuredOutputOptions.JSON,
|
||||
'{"type": "object"}',
|
||||
id="json",
|
||||
),
|
||||
pytest.param(
|
||||
StructuredOutputOptions.GRAMMAR,
|
||||
'start: "hello" | "world"',
|
||||
id="lark",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_mistral_tokenizer_compile_grammar(
|
||||
mistral_tokenizer,
|
||||
request_type: StructuredOutputOptions,
|
||||
grammar_spec: str,
|
||||
) -> None:
|
||||
vllm_config = VllmConfig(
|
||||
structured_outputs_config=StructuredOutputsConfig(backend="guidance"),
|
||||
)
|
||||
backend = GuidanceBackend(
|
||||
vllm_config,
|
||||
tokenizer=mistral_tokenizer,
|
||||
vocab_size=mistral_tokenizer.vocab_size,
|
||||
)
|
||||
assert backend.ll_tokenizer is mistral_tokenizer.llg_tokenizer
|
||||
|
||||
grammar = backend.compile_grammar(request_type, grammar_spec)
|
||||
assert grammar is not None
|
||||
assert not grammar.is_terminated()
|
||||
|
||||
@@ -153,6 +153,10 @@ class RequestOutputKind(Enum):
|
||||
FINAL_ONLY = 2
|
||||
|
||||
|
||||
def _is_non_tekken_mistral(tokenizer: TokenizerLike) -> bool:
|
||||
return is_mistral_tokenizer(tokenizer) and not tokenizer.is_tekken
|
||||
|
||||
|
||||
class SamplingParams(
|
||||
PydanticMsgspecMixin,
|
||||
msgspec.Struct,
|
||||
@@ -801,16 +805,17 @@ class SamplingParams(
|
||||
# xgrammar with no fallback
|
||||
validate_xgrammar_grammar(self)
|
||||
elif backend.startswith("guidance"):
|
||||
if _is_non_tekken_mistral(tokenizer=tokenizer):
|
||||
raise ValueError(
|
||||
"Non-tekken Mistral tokenizers are not supported for the 'guidance'"
|
||||
" structured output backend. Please either use a more recent "
|
||||
"Mistral model, the ['xgrammar', 'outlines'] "
|
||||
"backends or tokenizer_mode='hf' instead."
|
||||
)
|
||||
# TODO: ideally we would have the LLTokenizer here as Lark syntax
|
||||
# allows <|special_token|> and similar, see
|
||||
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
|
||||
# Without tokenizer these are disallowed in grammars.
|
||||
if is_mistral_tokenizer(tokenizer):
|
||||
raise ValueError(
|
||||
"Mistral tokenizer is not supported for the 'guidance' "
|
||||
"structured output backend. Please use ['xgrammar', 'outlines'] "
|
||||
"backends or tokenizer_mode='hf' instead."
|
||||
)
|
||||
validate_guidance_grammar(self, tokenizer=None)
|
||||
elif backend == "outlines":
|
||||
# outlines backend
|
||||
@@ -839,19 +844,20 @@ class SamplingParams(
|
||||
# or includes some jsonschema feature(s) that
|
||||
# are not supported in xgrammar.
|
||||
|
||||
skip_guidance = _is_non_tekken_mistral(tokenizer)
|
||||
|
||||
# Check if schema has features unsupported by guidance
|
||||
so_params = self.structured_outputs
|
||||
skip_guidance = False
|
||||
if so_params.json:
|
||||
if not skip_guidance and so_params.json:
|
||||
if isinstance(so_params.json, str):
|
||||
schema = json_mod.loads(so_params.json)
|
||||
else:
|
||||
schema = so_params.json
|
||||
skip_guidance = has_guidance_unsupported_json_features(schema)
|
||||
|
||||
if is_mistral_tokenizer(tokenizer) or skip_guidance:
|
||||
# Fall back to outlines if the tokenizer is Mistral
|
||||
# or if schema contains features unsupported by guidance
|
||||
if skip_guidance:
|
||||
# Fall back to outlines if the tokenizer is non-tekken Mistral or
|
||||
# the schema contains features unsupported by guidance
|
||||
validate_structured_output_request_outlines(self)
|
||||
self.structured_outputs._backend = "outlines"
|
||||
else:
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Sequence
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, cast, overload
|
||||
|
||||
from mistral_common.guidance.grammar_factory import GrammarFactory
|
||||
from mistral_common.guidance.tokenizer import from_mistral_tokenizer
|
||||
from mistral_common.protocol.instruct.request import (
|
||||
ChatCompletionRequest as MistralChatCompletionRequest,
|
||||
)
|
||||
@@ -45,6 +48,7 @@ except ImportError:
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import llguidance
|
||||
from transformers import BatchEncoding
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -574,3 +578,24 @@ class MistralTokenizer(TokenizerLike):
|
||||
]
|
||||
|
||||
return tokens
|
||||
|
||||
@property
|
||||
def supports_grammar(self) -> bool:
|
||||
return GrammarFactory.is_supported(self.mistral)
|
||||
|
||||
@cached_property
|
||||
def grammar_factory(self) -> GrammarFactory:
|
||||
if not self.supports_grammar:
|
||||
raise AttributeError(
|
||||
"This tokenizer does not support `grammar_factory`. "
|
||||
"This is only supported for tekken tokenizers with "
|
||||
"version >= 11."
|
||||
)
|
||||
# Cache grammar factory to avoid creating a llguidance tokenizer at every usage.
|
||||
return GrammarFactory(self.mistral)
|
||||
|
||||
@cached_property
|
||||
def llg_tokenizer(self) -> "llguidance.LLTokenizer":
|
||||
if not self.is_tekken:
|
||||
raise ValueError("`llg_tokenizer` is only supported for Tekkenizers.")
|
||||
return from_mistral_tokenizer(self.mistral)
|
||||
|
||||
@@ -10,6 +10,18 @@ from typing import Any
|
||||
|
||||
import ijson
|
||||
import regex as re
|
||||
from mistral_common.protocol.instruct.tool_calls import (
|
||||
NamedToolChoice as MistralNamedToolChoice,
|
||||
)
|
||||
from mistral_common.protocol.instruct.tool_calls import (
|
||||
Tool as MistralTool,
|
||||
)
|
||||
from mistral_common.protocol.instruct.tool_calls import (
|
||||
ToolChoice as MistralToolChoice,
|
||||
)
|
||||
from mistral_common.protocol.instruct.tool_calls import (
|
||||
ToolChoiceEnum as MistralToolChoiceEnum,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
@@ -25,6 +37,7 @@ from vllm.entrypoints.openai.engine.protocol import (
|
||||
)
|
||||
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import StructuredOutputsParams
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
Tool,
|
||||
@@ -36,6 +49,8 @@ logger = init_logger(__name__)
|
||||
|
||||
ALPHANUMERIC = ascii_letters + digits
|
||||
|
||||
_DEFAULT_JSON_SCHEMA = {"anyOf": [{"type": "object"}, {"type": "array"}]}
|
||||
|
||||
|
||||
class StreamingState(Enum):
|
||||
"""Enum for tracking the current streaming parsing state."""
|
||||
@@ -80,6 +95,9 @@ class MistralToolParser(ToolParser):
|
||||
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
|
||||
"""
|
||||
|
||||
# Used to generate correct grammar in `adjust_request`
|
||||
model_can_reason: bool = False
|
||||
|
||||
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
|
||||
super().__init__(tokenizer, tools)
|
||||
|
||||
@@ -115,12 +133,34 @@ class MistralToolParser(ToolParser):
|
||||
def adjust_request(
|
||||
self, request: ChatCompletionRequest | ResponsesRequest
|
||||
) -> ChatCompletionRequest | ResponsesRequest:
|
||||
request = super().adjust_request(request)
|
||||
so_non_supported_attributes = [
|
||||
"regex",
|
||||
"choice",
|
||||
"grammar",
|
||||
# whitespace_pattern is not a constraint type but an option;
|
||||
# Mistral grammar factory does not support it.
|
||||
"whitespace_pattern",
|
||||
"structural_tag",
|
||||
]
|
||||
any_so_non_supported_active = request.structured_outputs is not None and any(
|
||||
getattr(request.structured_outputs, attribute) is not None
|
||||
for attribute in so_non_supported_attributes
|
||||
)
|
||||
response_format_non_supported_active = (
|
||||
isinstance(request, ResponsesRequest)
|
||||
or request.response_format is not None
|
||||
and request.response_format.type == "structural_tag"
|
||||
)
|
||||
|
||||
if (
|
||||
not is_mistral_tokenizer(self.model_tokenizer)
|
||||
and request.tools
|
||||
and request.tool_choice != "none"
|
||||
or isinstance(request, ResponsesRequest)
|
||||
or not self.model_tokenizer.supports_grammar
|
||||
or any_so_non_supported_active
|
||||
or response_format_non_supported_active
|
||||
):
|
||||
request = super().adjust_request(request)
|
||||
if request.tools and request.tool_choice != "none":
|
||||
# Do not skip special tokens when using chat template
|
||||
# with Mistral parser as TOOL_CALL token is needed
|
||||
# for tool detection.
|
||||
@@ -129,6 +169,90 @@ class MistralToolParser(ToolParser):
|
||||
request.skip_special_tokens = False
|
||||
return request
|
||||
|
||||
json_schema: dict[str, Any] | None = None
|
||||
if request.structured_outputs is not None:
|
||||
if request.structured_outputs.json_object is not None:
|
||||
json_schema = _DEFAULT_JSON_SCHEMA
|
||||
elif request.structured_outputs.json is not None:
|
||||
if isinstance(request.structured_outputs.json, str):
|
||||
json_schema = json.loads(request.structured_outputs.json)
|
||||
else:
|
||||
json_schema = request.structured_outputs.json
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported request.structured_outputs for MistralToolParser. "
|
||||
"Only `json` and `json_object` are supported."
|
||||
)
|
||||
elif (
|
||||
request.response_format is not None
|
||||
and request.response_format.type != "text"
|
||||
):
|
||||
if request.response_format.type == "json_object":
|
||||
json_schema = _DEFAULT_JSON_SCHEMA
|
||||
elif request.response_format.type == "json_schema":
|
||||
if request.response_format.json_schema is not None:
|
||||
json_schema = request.response_format.json_schema.json_schema
|
||||
else:
|
||||
json_schema = _DEFAULT_JSON_SCHEMA
|
||||
else:
|
||||
raise ValueError(
|
||||
"MistralToolParser only accepts `text`, `json_object` or "
|
||||
f"`json_schema`, got {request.response_format=}"
|
||||
)
|
||||
# Structured Outputs will be defined.
|
||||
request.response_format = None
|
||||
|
||||
grammar_factory = self.model_tokenizer.grammar_factory
|
||||
|
||||
# TODO: Once unified parser, improve this.
|
||||
# The issue is figuring out when a model is a reasoning one or not.
|
||||
template = grammar_factory.select_jinja_template(
|
||||
reasoning=self.model_can_reason
|
||||
)
|
||||
|
||||
tools = (
|
||||
[
|
||||
MistralTool.from_openai(openai_tool=tool.model_dump())
|
||||
for tool in request.tools
|
||||
]
|
||||
if request.tools is not None
|
||||
else None
|
||||
)
|
||||
|
||||
tool_choice: MistralToolChoice
|
||||
match request.tool_choice:
|
||||
case "none" | "auto" | "required":
|
||||
tool_choice = MistralToolChoiceEnum(request.tool_choice)
|
||||
case None:
|
||||
tool_choice = MistralToolChoiceEnum.auto
|
||||
# _ == Named tool choice
|
||||
case _:
|
||||
tool_choice = MistralNamedToolChoice.model_validate(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {"name": request.tool_choice.function.name},
|
||||
}
|
||||
)
|
||||
|
||||
# Rendering grammar is cached in mistral-common given tools, template and mode.
|
||||
match tool_choice, json_schema is not None:
|
||||
case MistralToolChoiceEnum.none, True:
|
||||
lark_grammar = grammar_factory.get_lark_for_json_schema(
|
||||
template=template, json_schema=json_schema
|
||||
)
|
||||
case _, _:
|
||||
lark_grammar = grammar_factory.get_lark_from_jinja(
|
||||
template=template,
|
||||
mode=tool_choice,
|
||||
tools=tools,
|
||||
json_schema=json_schema,
|
||||
parallel_tool_calls=request.parallel_tool_calls,
|
||||
json_only=False,
|
||||
)
|
||||
|
||||
request.structured_outputs = StructuredOutputsParams(grammar=lark_grammar)
|
||||
return request
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
|
||||
@@ -12,6 +12,7 @@ import torch
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
from vllm.v1.structured_output.backend_types import (
|
||||
StructuredOutputBackend,
|
||||
StructuredOutputGrammar,
|
||||
@@ -92,6 +93,9 @@ class GuidanceBackend(StructuredOutputBackend):
|
||||
self.vllm_config.structured_outputs_config.disable_additional_properties
|
||||
)
|
||||
|
||||
if is_mistral_tokenizer(self.tokenizer):
|
||||
self.ll_tokenizer = self.tokenizer.llg_tokenizer
|
||||
else:
|
||||
self.ll_tokenizer = llguidance_hf.from_tokenizer(
|
||||
self.tokenizer, max(self.vocab_size, len(self.tokenizer))
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user