[Bugfix] Fix guided decoding with tokenizer mode mistral (#11046)

This commit is contained in:
Wallas Henrique
2024-12-18 03:34:08 -03:00
committed by GitHub
parent 866fa4550d
commit 8b79f9e107
7 changed files with 217 additions and 52 deletions

View File

@@ -3,17 +3,20 @@
Run `pytest tests/models/test_mistral.py`.
"""
import copy
import json
import jsonschema
import jsonschema.exceptions
import pytest
from vllm import SamplingParams
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa
MistralToolParser)
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from ...utils import check_logprobs_close
MODELS = [
"mistralai/Mistral-7B-Instruct-v0.1",
"mistralai/Mistral-7B-Instruct-v0.3",
]
MISTRAL_FORMAT_MODELS = [
@@ -126,6 +129,45 @@ MSGS = [
}
]
SAMPLE_JSON_SCHEMA = {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"age": {
"type": "integer"
},
"skills": {
"type": "array",
"items": {
"type": "string",
"maxLength": 10
},
"minItems": 3
},
"work_history": {
"type": "array",
"items": {
"type": "object",
"properties": {
"company": {
"type": "string"
},
"duration": {
"type": "number"
},
"position": {
"type": "string"
}
},
"required": ["company", "position"]
}
}
},
"required": ["name", "age", "skills", "work_history"]
}
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@@ -251,3 +293,43 @@ def test_mistral_function_calling(
assert parsed_message.tool_calls[
0].function.arguments == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}' # noqa
assert parsed_message.content is None
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("guided_backend",
["outlines", "lm-format-enforcer", "xgrammar"])
def test_mistral_guided_decoding(
vllm_runner,
model: str,
guided_backend: str,
) -> None:
with vllm_runner(model, dtype='bfloat16',
tokenizer_mode="mistral") as vllm_model:
guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA,
backend=guided_backend)
params = SamplingParams(max_tokens=512,
temperature=0.7,
guided_decoding=guided_decoding)
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role":
"user",
"content":
f"Give an example JSON for an employee profile that "
f"fits this schema: {SAMPLE_JSON_SCHEMA}"
}]
outputs = vllm_model.model.chat(messages, sampling_params=params)
generated_text = outputs[0].outputs[0].text
json_response = json.loads(generated_text)
assert outputs is not None
try:
jsonschema.validate(instance=json_response,
schema=SAMPLE_JSON_SCHEMA)
except jsonschema.exceptions.ValidationError:
pytest.fail("Generated response is not valid with JSON schema")