[Mistral Grammar] Support Grammar Factory (#38150)

Signed-off-by: juliendenize <julien.denize@mistral.ai>
This commit is contained in:
Julien Denize
2026-04-06 16:28:51 +02:00
committed by GitHub
parent c5e3454e5a
commit fef56c1855
10 changed files with 601 additions and 29 deletions

View File

@@ -31,7 +31,7 @@ partial-json-parser # used for parsing partial JSON outputs
pyzmq >= 25.0.0
msgspec
gguf >= 0.17.0
mistral_common[image] >= 1.10.0
mistral_common[image] >= 1.11.0
opencv-python-headless >= 4.13.0 # required for video IO
pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12

View File

@@ -604,7 +604,7 @@ mcp==1.27.0
# via -r requirements/common.txt
mdurl==0.1.2
# via markdown-it-py
mistral-common==1.10.0
mistral-common==1.11.0
# via
# -c requirements/common.txt
# -r requirements/common.txt

View File

@@ -508,7 +508,7 @@ mbstrdecoder==1.1.3
# typepy
mdurl==0.1.2
# via markdown-it-py
mistral-common==1.10.0
mistral-common==1.11.0
# via
# -c requirements/common.txt
# -r requirements/test.in

View File

@@ -3,8 +3,10 @@
from typing import Any
import llguidance
import pytest
from mistral_common.exceptions import InvalidMessageStructureException
from mistral_common.guidance.grammar_factory import GrammarFactory
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from vllm.tokenizers.mistral import (
@@ -2407,3 +2409,29 @@ class TestMistralTokenizer:
assert actual_tokens == expected_tokens
assert mistral_tokenizer.convert_ids_to_tokens([]) == []
def test_grammar_factory(self, mistral_tokenizer: MistralTokenizer) -> None:
    """`grammar_factory` raises for non-tekken tokenizers and is cached."""
    # Works in this case because Mistral 7B is < v11 and SPM.
    if not mistral_tokenizer.is_tekken:
        with pytest.raises(AttributeError):
            mistral_tokenizer.grammar_factory  # noqa: B018
        return

    first = mistral_tokenizer.grammar_factory
    assert isinstance(first, GrammarFactory)
    # Cached: repeated access must hand back the very same object.
    assert mistral_tokenizer.grammar_factory is first
def test_llg_tokenizer(self, mistral_tokenizer: MistralTokenizer) -> None:
    """`llg_tokenizer` is tekken-only and cached across accesses."""
    if not mistral_tokenizer.is_tekken:
        with pytest.raises(ValueError):
            mistral_tokenizer.llg_tokenizer  # noqa: B018
        return

    first = mistral_tokenizer.llg_tokenizer
    assert isinstance(first, llguidance.LLTokenizer)
    # Cached: the property must return the same object every time.
    assert mistral_tokenizer.llg_tokenizer is first

View File

@@ -3,19 +3,43 @@
import json
from collections.abc import Generator
from unittest.mock import MagicMock, patch
import partial_json_parser
import pytest
from mistral_common.protocol.instruct.messages import AssistantMessage
from mistral_common.protocol.instruct.request import InstructRequest
from mistral_common.protocol.instruct.tool_calls import FunctionCall, ToolCall
from mistral_common.protocol.instruct.tool_calls import (
FunctionCall,
ToolCall,
)
from mistral_common.protocol.instruct.tool_calls import (
NamedToolChoice as MistralNamedToolChoice,
)
from mistral_common.protocol.instruct.tool_calls import (
ToolChoice as MistralToolChoice,
)
from mistral_common.protocol.instruct.tool_calls import (
ToolChoiceEnum as MistralToolChoiceEnum,
)
from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.engine.protocol import DeltaMessage, DeltaToolCall
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaMessage,
DeltaToolCall,
StructuralTagResponseFormat,
)
from vllm.sampling_params import StructuredOutputsParams
from vllm.tokenizers import TokenizerLike, get_tokenizer
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.tool_parsers.mistral_tool_parser import MistralToolParser
from vllm.tool_parsers.mistral_tool_parser import (
_DEFAULT_JSON_SCHEMA,
MistralToolParser,
)
@pytest.fixture(scope="module")
@@ -40,6 +64,13 @@ def mistral_tool_parser(mistral_tokenizer):
return MistralToolParser(mistral_tokenizer)
@pytest.fixture
def non_mistral_parser() -> MistralToolParser:
    """Parser backed by a mock tokenizer that is not a MistralTokenizer."""
    tokenizer_stub = MagicMock()
    # The parser only needs the TOOL_CALLS token to be present in the vocab.
    tokenizer_stub.get_vocab.return_value = {"[TOOL_CALLS]": 1}
    return MistralToolParser(tokenizer_stub)
def assert_tool_calls(
actual_tool_calls: list[ToolCall] | list[DeltaToolCall],
expected_tool_calls: list[ToolCall],
@@ -951,3 +982,313 @@ def test_fast_detokenization_text_detection_pre_v11(
assert len(delta_message.tool_calls) > 0
assert delta_message.tool_calls[0].function is not None
assert delta_message.tool_calls[0].function.name == "add"
# Two simple function tools (one required string arg, two required number
# args) shared by the grammar-factory request tests below.
SAMPLE_TOOLS_DICTS = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the weather",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "add",
            "description": "Add two numbers",
            "parameters": {
                "type": "object",
                "properties": {
                    "a": {"type": "number"},
                    "b": {"type": "number"},
                },
                "required": ["a", "b"],
            },
        },
    },
]
def _make_request(**overrides) -> ChatCompletionRequest:
    """Build a ChatCompletionRequest with test defaults, overridable per test."""
    params: dict = {
        "messages": [],
        "model": "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
        "tools": SAMPLE_TOOLS_DICTS,
        "tool_choice": "auto",
    }
    params.update(overrides)
    return ChatCompletionRequest(**params)
@pytest.mark.parametrize(
    "request_kwargs,expected_mode,expected_parallel",
    [
        ({"tool_choice": "auto"}, MistralToolChoiceEnum.auto, True),
        ({"tool_choice": "none"}, MistralToolChoiceEnum.none, True),
        ({"tool_choice": "required"}, MistralToolChoiceEnum.required, True),
        ({"tool_choice": None, "tools": None}, MistralToolChoiceEnum.auto, True),
        (
            {
                "tool_choice": {
                    "type": "function",
                    "function": {"name": "get_weather"},
                }
            },
            MistralNamedToolChoice.model_validate(
                {"type": "function", "function": {"name": "get_weather"}}
            ),
            True,
        ),
        (
            {"tool_choice": "auto", "parallel_tool_calls": False},
            MistralToolChoiceEnum.auto,
            False,
        ),
        (
            {"tool_choice": "auto", "response_format": {"type": "text"}},
            MistralToolChoiceEnum.auto,
            True,
        ),
    ],
    ids=[
        "auto",
        "none",
        "required",
        "null_tool_choice",
        "named_tool_choice",
        "parallel_false",
        "response_format_text",
    ],
)
def test_adjust_request_grammar_factory(
    mistral_tool_parser: MistralToolParser,
    request_kwargs: dict,
    expected_mode: MistralToolChoice,
    expected_parallel: bool,
) -> None:
    """adjust_request renders a lark grammar via get_lark_from_jinja."""
    request = _make_request(**request_kwargs)
    grammar_factory = mistral_tool_parser.model_tokenizer.grammar_factory
    with patch.object(
        grammar_factory,
        "get_lark_from_jinja",
        wraps=grammar_factory.get_lark_from_jinja,
    ) as spy:
        adjusted = mistral_tool_parser.adjust_request(request)

    spy.assert_called_once()
    passed = spy.call_args.kwargs
    assert passed["mode"] == expected_mode
    assert passed["json_schema"] is None
    assert passed["parallel_tool_calls"] == expected_parallel

    outputs = adjusted.structured_outputs
    assert outputs is not None
    assert isinstance(outputs.grammar, str)
    assert len(outputs.grammar) > 0
def test_adjust_request_unsupported_grammar_for_tokenizer(mistral_tokenizer) -> None:
    """When the tokenizer reports no grammar support, no grammar is attached."""
    # Patch the class-level property so supports_grammar reads False.
    supports_grammar_off = patch.object(
        type(mistral_tokenizer),
        "supports_grammar",
        new_callable=lambda: property(lambda self: False),
    )
    with supports_grammar_off:
        parser = MistralToolParser(mistral_tokenizer)
        adjusted = parser.adjust_request(_make_request())
        assert adjusted.structured_outputs is None
@pytest.mark.parametrize(
    "tool_choice,expected_skip",
    [("auto", False), ("none", True)],
    ids=["auto_skip_false", "none_skip_true"],
)
def test_adjust_request_non_mistral_tokenizer(
    non_mistral_parser: MistralToolParser,
    tool_choice: str,
    expected_skip: bool,
) -> None:
    """Non-Mistral tokenizers only toggle skip_special_tokens."""
    adjusted = non_mistral_parser.adjust_request(
        _make_request(tool_choice=tool_choice)
    )
    assert adjusted.skip_special_tokens is expected_skip
@pytest.mark.parametrize(
    "so_kwargs",
    [
        {"regex": r"\d+"},
        {"choice": ["a", "b"]},
        {"structural_tag": '{"key": "value"}'},
        {"grammar": "start: 'hello'"},
    ],
    ids=["regex", "choice", "structural_tag", "grammar"],
)
def test_adjust_request_unsupported_structured_outputs(
    mistral_tool_parser: MistralToolParser,
    so_kwargs: dict,
) -> None:
    """Constraint types the grammar factory cannot handle pass through untouched."""
    request = _make_request(
        structured_outputs=StructuredOutputsParams(**so_kwargs),
    )
    adjusted = mistral_tool_parser.adjust_request(request)
    # No grammar is generated; the caller's constraints are kept as-is.
    assert adjusted.structured_outputs == request.structured_outputs
def test_adjust_request_unsupported_response_format(
    mistral_tool_parser: MistralToolParser,
) -> None:
    """A structural_tag response_format is passed through without a grammar."""
    structural_tag = StructuralTagResponseFormat(
        type="structural_tag", format={"some": "config"}
    )
    request = _make_request(response_format=structural_tag)
    adjusted = mistral_tool_parser.adjust_request(request)
    assert adjusted.structured_outputs is None
    assert adjusted.response_format == request.response_format
@pytest.mark.parametrize(
    "so_kwargs,expected_json_schema",
    [
        ({"json_object": True}, _DEFAULT_JSON_SCHEMA),
        ({"json": '{"type": "object"}'}, {"type": "object"}),
        (
            {"json": {"type": "object", "properties": {"x": {"type": "integer"}}}},
            {"type": "object", "properties": {"x": {"type": "integer"}}},
        ),
    ],
    ids=["json_object", "json_str", "json_dict"],
)
def test_adjust_request_structured_outputs_generates_grammar(
    mistral_tool_parser: MistralToolParser,
    so_kwargs: dict,
    # Fixed annotation: every parametrized expected value above is a dict,
    # not a str (the previous `str` annotation was incorrect).
    expected_json_schema: dict,
) -> None:
    """json/json_object structured outputs are converted into a lark grammar."""
    request = _make_request(
        structured_outputs=StructuredOutputsParams(**so_kwargs),
    )
    factory = mistral_tool_parser.model_tokenizer.grammar_factory
    with patch.object(
        factory,
        "get_lark_from_jinja",
        wraps=factory.get_lark_from_jinja,
    ) as mock_get_lark:
        result = mistral_tool_parser.adjust_request(request)
    mock_get_lark.assert_called_once()
    # The JSON schema (parsed from str if needed) is forwarded to the factory.
    assert mock_get_lark.call_args.kwargs["json_schema"] == expected_json_schema
    assert result.structured_outputs is not None
    assert isinstance(result.structured_outputs.grammar, str)
    assert len(result.structured_outputs.grammar) > 0
@pytest.mark.parametrize(
    "response_format_kwargs,expected_json_schema",
    [
        ({"type": "json_object"}, _DEFAULT_JSON_SCHEMA),
        (
            {
                "type": "json_schema",
                "json_schema": {
                    "name": "my_schema",
                    "schema": {
                        "type": "object",
                        "properties": {"x": {"type": "integer"}},
                    },
                },
            },
            {"type": "object", "properties": {"x": {"type": "integer"}}},
        ),
    ],
    ids=["json_object", "json_schema_with_schema"],
)
def test_adjust_request_response_format_generates_grammar(
    mistral_tool_parser: MistralToolParser,
    response_format_kwargs: dict,
    # Fixed annotation: both parametrized expected values above are dicts,
    # not a str (the previous `str` annotation was incorrect).
    expected_json_schema: dict,
) -> None:
    """json_object / json_schema response formats produce a lark grammar."""
    request = _make_request(response_format=response_format_kwargs)
    factory = mistral_tool_parser.model_tokenizer.grammar_factory
    with patch.object(
        factory,
        "get_lark_from_jinja",
        wraps=factory.get_lark_from_jinja,
    ) as mock_get_lark:
        result = mistral_tool_parser.adjust_request(request)
    mock_get_lark.assert_called_once()
    assert mock_get_lark.call_args.kwargs["json_schema"] == expected_json_schema
    assert result.structured_outputs is not None
    assert isinstance(result.structured_outputs.grammar, str)
    assert len(result.structured_outputs.grammar) > 0
def test_adjust_request_tool_choice_none_with_json_schema_uses_json_schema_factory(
    mistral_tool_parser: MistralToolParser,
) -> None:
    """tool_choice="none" plus a JSON schema uses get_lark_for_json_schema."""
    request = _make_request(
        tool_choice="none",
        structured_outputs=StructuredOutputsParams(json='{"type": "object"}'),
    )
    factory = mistral_tool_parser.model_tokenizer.grammar_factory
    with patch.object(
        factory,
        "get_lark_for_json_schema",
        wraps=factory.get_lark_for_json_schema,
    ) as json_schema_spy:
        adjusted = mistral_tool_parser.adjust_request(request)

    json_schema_spy.assert_called_once()
    assert json_schema_spy.call_args.kwargs["json_schema"] == {"type": "object"}

    outputs = adjusted.structured_outputs
    assert outputs is not None
    assert isinstance(outputs.grammar, str)
    assert len(outputs.grammar) > 0
def test_adjust_request_tool_choice_auto_with_json_schema_uses_jinja_factory(
    mistral_tool_parser: MistralToolParser,
) -> None:
    """tool_choice="auto" plus a JSON schema must use only the jinja path."""
    request = _make_request(
        tool_choice="auto",
        structured_outputs=StructuredOutputsParams(json='{"type": "object"}'),
    )
    factory = mistral_tool_parser.model_tokenizer.grammar_factory
    with (
        patch.object(
            factory,
            "get_lark_for_json_schema",
            wraps=factory.get_lark_for_json_schema,
        ) as json_schema_spy,
        patch.object(
            factory,
            "get_lark_from_jinja",
            wraps=factory.get_lark_from_jinja,
        ) as jinja_spy,
    ):
        adjusted = mistral_tool_parser.adjust_request(request)

    jinja_spy.assert_called_once()
    assert jinja_spy.call_args.kwargs["json_schema"] == {"type": "object"}
    # The dedicated json-schema factory must NOT be used for "auto".
    json_schema_spy.assert_not_called()

    outputs = adjusted.structured_outputs
    assert outputs is not None
    assert isinstance(outputs.grammar, str)
    assert len(outputs.grammar) > 0

View File

@@ -11,6 +11,7 @@ from vllm.config.model import ModelConfig
from vllm.config.parallel import ParallelConfig
from vllm.config.speculative import SpeculativeConfig
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.tokenizers import get_tokenizer
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
@@ -19,6 +20,14 @@ from vllm.v1.structured_output.backend_types import StructuredOutputOptions
TOKENIZER = "gpt2"
@pytest.fixture(scope="module")
def mistral_tokenizer():
    """Module-scoped Mistral tekken tokenizer shared by the guidance tests."""
    model_id = "mistralai/Mistral-Small-3.2-24B-Instruct-2506"
    return get_tokenizer(tokenizer_name=model_id, tokenizer_mode="mistral")
def test_backend_guidance_rollback_terminated():
# Test that the backend guidance successfully rollbacks from a
# terminated state. This can happen with speculative decoding,
@@ -187,3 +196,38 @@ def test_grammar_init_async_and_sync(async_grammar):
# Verify the grammar can accept valid tokens
assert grammar.accept_tokens(request.request_id, prompt)
@pytest.mark.parametrize(
    "request_type,grammar_spec",
    [
        pytest.param(StructuredOutputOptions.JSON, '{"type": "object"}', id="json"),
        pytest.param(
            StructuredOutputOptions.GRAMMAR, 'start: "hello" | "world"', id="lark"
        ),
    ],
)
def test_mistral_tokenizer_compile_grammar(
    mistral_tokenizer,
    request_type: StructuredOutputOptions,
    grammar_spec: str,
) -> None:
    """GuidanceBackend reuses the Mistral LLTokenizer and compiles grammars."""
    config = VllmConfig(
        structured_outputs_config=StructuredOutputsConfig(backend="guidance"),
    )
    backend = GuidanceBackend(
        config,
        tokenizer=mistral_tokenizer,
        vocab_size=mistral_tokenizer.vocab_size,
    )
    # The backend must pick up the tokenizer's cached LLTokenizer, not rebuild one.
    assert backend.ll_tokenizer is mistral_tokenizer.llg_tokenizer

    compiled = backend.compile_grammar(request_type, grammar_spec)
    assert compiled is not None
    assert not compiled.is_terminated()

View File

@@ -153,6 +153,10 @@ class RequestOutputKind(Enum):
FINAL_ONLY = 2
def _is_non_tekken_mistral(tokenizer: TokenizerLike) -> bool:
    """Return True for Mistral tokenizers that are not tekken-based."""
    if not is_mistral_tokenizer(tokenizer):
        return False
    return not tokenizer.is_tekken
class SamplingParams(
PydanticMsgspecMixin,
msgspec.Struct,
@@ -801,16 +805,17 @@ class SamplingParams(
# xgrammar with no fallback
validate_xgrammar_grammar(self)
elif backend.startswith("guidance"):
if _is_non_tekken_mistral(tokenizer=tokenizer):
raise ValueError(
"Non-tekken Mistral tokenizers are not supported for the 'guidance'"
" structured output backend. Please either use a more recent "
"Mistral model, the ['xgrammar', 'outlines'] "
"backends or tokenizer_mode='hf' instead."
)
# TODO: ideally we would have the LLTokenizer here as Lark syntax
# allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars.
if is_mistral_tokenizer(tokenizer):
raise ValueError(
"Mistral tokenizer is not supported for the 'guidance' "
"structured output backend. Please use ['xgrammar', 'outlines'] "
"backends or tokenizer_mode='hf' instead."
)
validate_guidance_grammar(self, tokenizer=None)
elif backend == "outlines":
# outlines backend
@@ -839,19 +844,20 @@ class SamplingParams(
# or includes some jsonschema feature(s) that
# are not supported in xgrammar.
skip_guidance = _is_non_tekken_mistral(tokenizer)
# Check if schema has features unsupported by guidance
so_params = self.structured_outputs
skip_guidance = False
if so_params.json:
if not skip_guidance and so_params.json:
if isinstance(so_params.json, str):
schema = json_mod.loads(so_params.json)
else:
schema = so_params.json
skip_guidance = has_guidance_unsupported_json_features(schema)
if is_mistral_tokenizer(tokenizer) or skip_guidance:
# Fall back to outlines if the tokenizer is Mistral
# or if schema contains features unsupported by guidance
if skip_guidance:
# Fall back to outlines if the tokenizer is non-tekken Mistral or
# the schema contains features unsupported by guidance
validate_structured_output_request_outlines(self)
self.structured_outputs._backend = "outlines"
else:

View File

@@ -1,9 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Any, cast, overload
from mistral_common.guidance.grammar_factory import GrammarFactory
from mistral_common.guidance.tokenizer import from_mistral_tokenizer
from mistral_common.protocol.instruct.request import (
ChatCompletionRequest as MistralChatCompletionRequest,
)
@@ -45,6 +48,7 @@ except ImportError:
)
if TYPE_CHECKING:
import llguidance
from transformers import BatchEncoding
logger = init_logger(__name__)
@@ -574,3 +578,24 @@ class MistralTokenizer(TokenizerLike):
]
return tokens
@property
def supports_grammar(self) -> bool:
    """Whether mistral-common can build a grammar factory for this tokenizer.

    Delegates to `GrammarFactory.is_supported` on the wrapped
    mistral-common tokenizer.
    """
    return GrammarFactory.is_supported(self.mistral)
@cached_property
def grammar_factory(self) -> GrammarFactory:
    """Lazily built, cached `GrammarFactory` for this tokenizer.

    Raises:
        AttributeError: if the tokenizer does not support grammars
            (per `supports_grammar`, i.e. it is not a tekken tokenizer
            with version >= 11).
    """
    if not self.supports_grammar:
        raise AttributeError(
            "This tokenizer does not support `grammar_factory`. "
            "This is only supported for tekken tokenizers with "
            "version >= 11."
        )
    # Cache grammar factory to avoid creating a llguidance tokenizer at every usage.
    return GrammarFactory(self.mistral)
@cached_property
def llg_tokenizer(self) -> "llguidance.LLTokenizer":
    """Cached llguidance `LLTokenizer` built from the wrapped tokenizer.

    Raises:
        ValueError: if this is not a tekken tokenizer, since
            `from_mistral_tokenizer` only supports Tekkenizers.
    """
    if not self.is_tekken:
        raise ValueError("`llg_tokenizer` is only supported for Tekkenizers.")
    return from_mistral_tokenizer(self.mistral)

View File

@@ -10,6 +10,18 @@ from typing import Any
import ijson
import regex as re
from mistral_common.protocol.instruct.tool_calls import (
NamedToolChoice as MistralNamedToolChoice,
)
from mistral_common.protocol.instruct.tool_calls import (
Tool as MistralTool,
)
from mistral_common.protocol.instruct.tool_calls import (
ToolChoice as MistralToolChoice,
)
from mistral_common.protocol.instruct.tool_calls import (
ToolChoiceEnum as MistralToolChoiceEnum,
)
from pydantic import Field
from vllm.entrypoints.openai.chat_completion.protocol import (
@@ -25,6 +37,7 @@ from vllm.entrypoints.openai.engine.protocol import (
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.sampling_params import StructuredOutputsParams
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
@@ -36,6 +49,8 @@ logger = init_logger(__name__)
ALPHANUMERIC = ascii_letters + digits
_DEFAULT_JSON_SCHEMA = {"anyOf": [{"type": "object"}, {"type": "array"}]}
class StreamingState(Enum):
"""Enum for tracking the current streaming parsing state."""
@@ -80,6 +95,9 @@ class MistralToolParser(ToolParser):
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
"""
# Used to generate correct grammar in `adjust_request`
model_can_reason: bool = False
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
@@ -115,12 +133,34 @@ class MistralToolParser(ToolParser):
def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
request = super().adjust_request(request)
so_non_supported_attributes = [
"regex",
"choice",
"grammar",
# whitespace_pattern is not a constraint type but an option;
# Mistral grammar factory does not support it.
"whitespace_pattern",
"structural_tag",
]
any_so_non_supported_active = request.structured_outputs is not None and any(
getattr(request.structured_outputs, attribute) is not None
for attribute in so_non_supported_attributes
)
response_format_non_supported_active = (
isinstance(request, ResponsesRequest)
or request.response_format is not None
and request.response_format.type == "structural_tag"
)
if (
not is_mistral_tokenizer(self.model_tokenizer)
and request.tools
and request.tool_choice != "none"
or isinstance(request, ResponsesRequest)
or not self.model_tokenizer.supports_grammar
or any_so_non_supported_active
or response_format_non_supported_active
):
request = super().adjust_request(request)
if request.tools and request.tool_choice != "none":
# Do not skip special tokens when using chat template
# with Mistral parser as TOOL_CALL token is needed
# for tool detection.
@@ -129,6 +169,90 @@ class MistralToolParser(ToolParser):
request.skip_special_tokens = False
return request
json_schema: dict[str, Any] | None = None
if request.structured_outputs is not None:
if request.structured_outputs.json_object is not None:
json_schema = _DEFAULT_JSON_SCHEMA
elif request.structured_outputs.json is not None:
if isinstance(request.structured_outputs.json, str):
json_schema = json.loads(request.structured_outputs.json)
else:
json_schema = request.structured_outputs.json
else:
raise ValueError(
"Unsupported request.structured_outputs for MistralToolParser. "
"Only `json` and `json_object` are supported."
)
elif (
request.response_format is not None
and request.response_format.type != "text"
):
if request.response_format.type == "json_object":
json_schema = _DEFAULT_JSON_SCHEMA
elif request.response_format.type == "json_schema":
if request.response_format.json_schema is not None:
json_schema = request.response_format.json_schema.json_schema
else:
json_schema = _DEFAULT_JSON_SCHEMA
else:
raise ValueError(
"MistralToolParser only accepts `text`, `json_object` or "
f"`json_schema`, got {request.response_format=}"
)
# Structured Outputs will be defined.
request.response_format = None
grammar_factory = self.model_tokenizer.grammar_factory
# TODO: Once unified parser, improve this.
# The issue is figuring out when a model is a reasoning one or not.
template = grammar_factory.select_jinja_template(
reasoning=self.model_can_reason
)
tools = (
[
MistralTool.from_openai(openai_tool=tool.model_dump())
for tool in request.tools
]
if request.tools is not None
else None
)
tool_choice: MistralToolChoice
match request.tool_choice:
case "none" | "auto" | "required":
tool_choice = MistralToolChoiceEnum(request.tool_choice)
case None:
tool_choice = MistralToolChoiceEnum.auto
# _ == Named tool choice
case _:
tool_choice = MistralNamedToolChoice.model_validate(
{
"type": "function",
"function": {"name": request.tool_choice.function.name},
}
)
# Rendering grammar is cached in mistral-common given tools, template and mode.
match tool_choice, json_schema is not None:
case MistralToolChoiceEnum.none, True:
lark_grammar = grammar_factory.get_lark_for_json_schema(
template=template, json_schema=json_schema
)
case _, _:
lark_grammar = grammar_factory.get_lark_from_jinja(
template=template,
mode=tool_choice,
tools=tools,
json_schema=json_schema,
parallel_tool_calls=request.parallel_tool_calls,
json_only=False,
)
request.structured_outputs = StructuredOutputsParams(grammar=lark_grammar)
return request
def extract_tool_calls(
self,
model_output: str,

View File

@@ -12,6 +12,7 @@ import torch
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.utils.import_utils import LazyLoader
from vllm.utils.mistral import is_mistral_tokenizer
from vllm.v1.structured_output.backend_types import (
StructuredOutputBackend,
StructuredOutputGrammar,
@@ -92,6 +93,9 @@ class GuidanceBackend(StructuredOutputBackend):
self.vllm_config.structured_outputs_config.disable_additional_properties
)
if is_mistral_tokenizer(self.tokenizer):
self.ll_tokenizer = self.tokenizer.llg_tokenizer
else:
self.ll_tokenizer = llguidance_hf.from_tokenizer(
self.tokenizer, max(self.vocab_size, len(self.tokenizer))
)