[V0 deprecation] Guided decoding (#21347)
Signed-off-by: Reza Barazesh <rezabarazesh@meta.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -3,13 +3,11 @@
|
||||
import copy
|
||||
import json
|
||||
|
||||
import jsonschema
|
||||
import jsonschema.exceptions
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
|
||||
MistralToolCall, MistralToolParser)
|
||||
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import MistralTokenizer
|
||||
|
||||
from ...utils import check_logprobs_close
|
||||
@@ -274,53 +272,6 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
|
||||
assert parsed_message.content is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("guided_backend",
|
||||
["outlines", "lm-format-enforcer", "xgrammar"])
|
||||
def test_mistral_guided_decoding(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
vllm_runner,
|
||||
model: str,
|
||||
guided_backend: str,
|
||||
) -> None:
|
||||
with monkeypatch.context() as m:
|
||||
# Guided JSON not supported in xgrammar + V1 yet
|
||||
m.setenv("VLLM_USE_V1", "0")
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
dtype='bfloat16',
|
||||
tokenizer_mode="mistral",
|
||||
guided_decoding_backend=guided_backend,
|
||||
) as vllm_model:
|
||||
guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA)
|
||||
params = SamplingParams(max_tokens=512,
|
||||
temperature=0.7,
|
||||
guided_decoding=guided_decoding)
|
||||
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
}, {
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
f"Give an example JSON for an employee profile that "
|
||||
f"fits this schema: {SAMPLE_JSON_SCHEMA}"
|
||||
}]
|
||||
outputs = vllm_model.llm.chat(messages, sampling_params=params)
|
||||
|
||||
generated_text = outputs[0].outputs[0].text
|
||||
json_response = json.loads(generated_text)
|
||||
assert outputs is not None
|
||||
|
||||
try:
|
||||
jsonschema.validate(instance=json_response,
|
||||
schema=SAMPLE_JSON_SCHEMA)
|
||||
except jsonschema.exceptions.ValidationError:
|
||||
pytest.fail("Generated response is not valid with JSON schema")
|
||||
|
||||
|
||||
def test_mistral_function_call_nested_json():
|
||||
"""Ensure that the function-name regex captures the entire outer-most
|
||||
JSON block, including nested braces."""
|
||||
|
||||
Reference in New Issue
Block a user