2025-05-14 18:45:24 -04:00
# ruff: noqa: E501
2025-03-07 10:19:11 -05:00
# SPDX-License-Identifier: Apache-2.0
2025-06-03 11:20:17 -07:00
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
2025-03-07 10:19:11 -05:00
2025-03-11 22:40:09 -04:00
from __future__ import annotations
2025-03-07 10:19:11 -05:00
import json
2025-03-28 21:14:53 +08:00
from enum import Enum
2025-05-14 18:45:24 -04:00
from typing import TYPE_CHECKING , Any
2025-03-07 10:19:11 -05:00
import jsonschema
import pytest
2025-05-24 07:16:26 +08:00
import regex as re
2025-03-28 21:14:53 +08:00
from pydantic import BaseModel
2025-03-07 10:19:11 -05:00
2025-05-14 18:45:24 -04:00
from tests . reasoning . utils import run_reasoning_extraction
2025-03-07 10:19:11 -05:00
from vllm . entrypoints . llm import LLM
from vllm . outputs import RequestOutput
2025-04-24 03:50:09 -06:00
from vllm . platforms import current_platform
2025-05-14 18:45:24 -04:00
from vllm . reasoning . abs_reasoning_parsers import ReasoningParserManager
2025-03-07 10:19:11 -05:00
from vllm . sampling_params import GuidedDecodingParams , SamplingParams
2025-05-14 18:45:24 -04:00
if TYPE_CHECKING :
from vllm . config import TokenizerMode
2025-04-29 17:02:10 -07:00
# Speculative-decoding config that uses prompt n-gram lookup (no draft model).
NGRAM_SPEC_CONFIG = {
    "model": "[ngram]",
    "num_speculative_tokens": 5,
    "prompt_lookup_max": 5,
    "prompt_lookup_min": 1,
}

# Speculative-decoding config that uses an EAGLE draft model.
EAGLE_SPEC_CONFIG = {
    "method": "eagle",
    "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
    "num_speculative_tokens": 5,
}

# Parametrization matrix for test_structured_output:
# (model_name, guided_decoding_backend, tokenizer_mode, speculative_config)
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
    ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None),
    ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
    ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None),
    ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None),
    #FIXME: This test is flaky on CI thus disabled
    #("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
    ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto",
     NGRAM_SPEC_CONFIG),
    ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", NGRAM_SPEC_CONFIG),
    ("meta-llama/Meta-Llama-3.1-8B-Instruct", "xgrammar", "auto",
     EAGLE_SPEC_CONFIG)
]

# Parametrization matrix for test_structured_output_auto_mode:
# (model_name, tokenizer_mode)
PARAMS_MODELS_TOKENIZER_MODE = [
    ("mistralai/Ministral-8B-Instruct-2410", "auto"),
    ("Qwen/Qwen2.5-1.5B-Instruct", "auto"),
]
2025-03-11 22:40:09 -04:00
2025-03-29 00:10:45 -04:00
class CarType(str, Enum):
    """Closed set of car categories used in the CarDescription schema.

    NOTE(review): member values deliberately mix casing ("sedan", "SUV",
    "Truck", "Coupe") — presumably to exercise exact enum-value matching
    in the structured-output backends; confirm before normalizing.
    """
    sedan = "sedan"
    suv = "SUV"
    truck = "Truck"
    coupe = "Coupe"
class CarDescription(BaseModel):
    """Pydantic model whose JSON schema (including a string enum) is fed to
    guided decoding in Test 9 of test_structured_output."""
    # Free-form brand name.
    brand: str
    # Free-form model name.
    model: str
    # Constrained to the CarType enum values.
    car_type: CarType
2025-05-12 18:31:54 -04:00
def _load_json ( s : str , backend : str ) - > str :
if backend != " xgrammar " :
return json . loads ( s )
# xgrammar specific workarounds
# https://github.com/mlc-ai/xgrammar/issues/286
s = re . sub ( r ' [ \ x00- \ x1F \ x7F- \ xFF] ' , ' ' , s )
return json . loads ( s )
2025-03-07 10:19:11 -05:00
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize(
    "model_name, guided_decoding_backend, tokenizer_mode, speculative_config",
    PARAMS_MODELS_BACKENDS_TOKENIZER_MODE)
def test_structured_output(
    monkeypatch: pytest.MonkeyPatch,
    sample_json_schema: dict[str, Any],
    unsupported_json_schema: dict[str, Any],
    sample_sql_ebnf: str,
    sample_sql_lark: str,
    sample_regex: str,
    sample_guided_choice: str,
    guided_decoding_backend: str,
    tokenizer_mode: str,
    model_name: str,
    speculative_config: dict[str, Any],
):
    """End-to-end structured-output matrix run against a single LLM instance.

    Exercises 11 guided-decoding scenarios (JSON schema, schemaless JSON,
    xgrammar-unsupported schema, EBNF grammar, Lark grammar, invalid grammar,
    regex, choice list, pydantic-enum schema, min/max string length, and
    structural tags) for one model/backend/tokenizer/spec-decode combination.

    The sample_* fixtures are defined elsewhere (presumably conftest.py —
    not visible in this file).
    """
    monkeypatch.setenv("VLLM_USE_V1", "1")

    if current_platform.is_tpu() and speculative_config:
        pytest.skip("TPU does not support speculative decoding")

    # Don't use eager execution on TPUs because we want to test for no
    # recompilation at runtime
    enforce_eager = bool(not current_platform.is_tpu())
    # Use a single LLM instance for several scenarios to
    # speed up the test suite.
    llm = LLM(model=model_name,
              enforce_eager=enforce_eager,
              max_model_len=1024,
              guided_decoding_backend=guided_decoding_backend,
              guided_decoding_disable_any_whitespace=True,
              tokenizer_mode=tokenizer_mode,
              speculative_config=speculative_config)

    #
    # Test 1: Generate JSON output based on a provided schema
    #
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=4096,
        guided_decoding=GuidedDecodingParams(json=sample_json_schema))

    outputs = llm.generate(prompts=[
        (f"Give an example JSON for an employee profile that fits this "
         f"schema. Make the response as short as possible. Schema: "
         f"{sample_json_schema}")
    ] * 2,
                           sampling_params=sampling_params,
                           use_tqdm=True)

    assert outputs is not None

    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt
        generated_text = output.outputs[0].text
        assert generated_text is not None
        # disable_any_whitespace=True means the backend must not emit
        # formatting whitespace such as newlines between JSON tokens.
        assert "\n" not in generated_text

        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        output_json = json.loads(generated_text)
        jsonschema.validate(instance=output_json, schema=sample_json_schema)

    #
    # Test 2: Generate JSON object without a schema
    #
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=4096,
        n=2,
        guided_decoding=GuidedDecodingParams(json_object=True))

    outputs = llm.generate(
        prompts=("Generate a JSON object with curly braces for a person with "
                 "name and age fields for John Smith who is 31 years old. "
                 "Make the response as short as possible."),
        sampling_params=sampling_params,
        use_tqdm=True)

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)

        # n=2 above, so each request carries two completions.
        for i in range(2):
            generated_text = output.outputs[i].text
            print(generated_text)
            assert generated_text is not None

            # Parse to verify it is a valid JSON object
            parsed_json = json.loads(generated_text)
            assert isinstance(parsed_json, dict)

    #
    # Test 3: test a jsonschema incompatible with xgrammar
    #
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=4096,
        guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
    if guided_decoding_backend.startswith("xgrammar"):
        # xgrammar rejects the schema up front rather than generating.
        with pytest.raises(ValueError,
                           match="The provided JSON schema contains features "
                           "not supported by xgrammar."):
            llm.generate(
                prompts=[(f"Give an example JSON for an employee profile that "
                          f"fits this schema: {unsupported_json_schema}. "
                          f"Make the response as short as possible.")] * 2,
                sampling_params=sampling_params,
                use_tqdm=True)
    else:
        # Other backends should handle the schema and produce valid JSON.
        outputs = llm.generate(prompts=(
            "Give an example JSON object for a grade "
            "that fits this schema: "
            f"{unsupported_json_schema}. Make the response as short as "
            "possible."),
                               sampling_params=sampling_params,
                               use_tqdm=True)
        assert outputs is not None
        for output in outputs:
            assert output is not None
            assert isinstance(output, RequestOutput)
            generated_text = output.outputs[0].text
            assert generated_text is not None
            print(generated_text)

            # Parse to verify it is valid JSON
            parsed_json = json.loads(generated_text)
            assert isinstance(parsed_json, dict)

    #
    # Test 4: Generate SQL statement using EBNF grammar
    #
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
    outputs = llm.generate(
        prompts=("Generate a sql statement that selects col_1 from "
                 "table_1 where it is equal to 1. Make the response as short "
                 "as possible."),
        sampling_params=sampling_params,
        use_tqdm=True,
    )

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt
        generated_text = output.outputs[0].text
        assert generated_text is not None

        # remove spaces for comparison b/c we removed them in the grammar
        ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
            " ", "")

        assert generated_text.strip() == ground_truth

        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    #
    # Test 5: Generate SQL statement using Lark grammar
    #
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
    outputs = llm.generate(
        prompts=("Generate a sql statement that selects col_1 from "
                 "table_1 where it is equal to 1. Make the response as short "
                 "as possible."),
        sampling_params=sampling_params,
        use_tqdm=True,
    )

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt
        generated_text = output.outputs[0].text
        assert generated_text is not None

        # use Lark to parse the output, and make sure it's a valid parse tree
        from lark import Lark
        parser = Lark(sample_sql_lark)
        parser.parse(generated_text)

        # remove spaces for comparison b/c we removed them in the grammar
        ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
            " ", "")

        assert generated_text.strip() == ground_truth

        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    #
    # Test 6: Test invalid grammar input
    #
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
    with pytest.raises(ValueError, match="Failed to convert the grammar "):
        llm.generate(
            prompts=("Generate a sql statement that selects col_1 from "
                     "table_1 where it is equal to 1. Make the response as "
                     "short as possible."),
            sampling_params=sampling_params,
            use_tqdm=True,
        )

    #
    # Test 7: Generate text based on a regex pattern
    #
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        guided_decoding=GuidedDecodingParams(regex=sample_regex))
    outputs = llm.generate(
        prompts=[
            (f"Give an example IPv4 address with this regex: {sample_regex}. "
             f"Make the response as short as possible.")
        ] * 2,
        sampling_params=sampling_params,
        use_tqdm=True,
    )

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(generated_text)
        assert generated_text is not None
        # Whole output must match the constraining regex, not a substring.
        assert re.fullmatch(sample_regex, generated_text) is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    #
    # Test 8: Generate text based on a choices
    #
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        guided_decoding=GuidedDecodingParams(choice=sample_guided_choice))
    outputs = llm.generate(
        prompts=("The best language for type-safe systems programming is "
                 "(Make the response as short as possible.)"),
        sampling_params=sampling_params,
        use_tqdm=True)
    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(generated_text)
        assert generated_text is not None
        # Output must be exactly one of the allowed choices.
        assert generated_text in sample_guided_choice
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    #
    # Test 9: Generate structured output using a Pydantic model with an enum
    #
    json_schema = CarDescription.model_json_schema()
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(json=json_schema))
    outputs = llm.generate(prompts=(
        "Generate a JSON with the brand, model and car_type of the most "
        "iconic car from the 90's. Make the response as short as "
        "possible."),
                           sampling_params=sampling_params,
                           use_tqdm=True)

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt
        generated_text = output.outputs[0].text
        assert generated_text is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        output_json = json.loads(generated_text)
        jsonschema.validate(instance=output_json, schema=json_schema)

    #
    # Test 10: Generate structured with minLength and maxLength
    #
    min_length = 50
    max_length = 50
    json_schema = {
        "type": "object",
        "properties": {
            "description": {
                "type": "string",
                "maxLength": max_length,
                "minLength": min_length
            }
        },
        "required": ["description"],
        "additionalProperties": False
    }
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=4096,
        guided_decoding=GuidedDecodingParams(json=json_schema))

    outputs = llm.generate(
        prompts=("Generate a description of a frog using 50 characters. "
                 "Make the response as short as possible."),
        sampling_params=sampling_params,
        use_tqdm=True)
    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt
        generated_text = output.outputs[0].text
        assert generated_text is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        output_json = json.loads(generated_text)
        jsonschema.validate(instance=output_json, schema=json_schema)

    #
    # Test 11: Generate structured output using structural_tag format
    #
    structural_tag_config = {
        "type":
        "structural_tag",
        "structures": [{
            "begin": "<function=get_weather>",
            "schema": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string"
                    }
                },
                "additionalProperties": False
            },
            "end": "</function>"
        }],
        "triggers": ["<function="]
    }
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=4096,
        guided_decoding=GuidedDecodingParams(
            structural_tag=json.dumps(structural_tag_config)))

    prompt = """
You have access to the following function to retrieve the weather in a city:

    {
        "name": "get_weather",
        "parameters": {
            "city": {
                "param_type": "string",
                "description": "The city to get the weather for",
                "required": True
            }
        }
    }

If a you choose to call a function ONLY reply in the following format:
<{start_tag}={function_name}>{parameters}{end_tag}
where
start_tag => `<function`
parameters => a JSON dict with the function argument name
    as key and function argument value as value.
end_tag => `</function>`
Here is an example,
<function=example_function_name>{"example_name": "example_value"}</function>
Reminder:
- Function calls MUST follow the specified format
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
- Always add your sources when using search results to answer the user query
You are a helpful assistant.

Given the previous instructions, what is the weather in New York City? \
Make the response as short as possible.
"""

    # Change this once other backends support structural_tag
    outputs = llm.generate(prompts=prompt,
                           sampling_params=sampling_params,
                           use_tqdm=True)
    assert outputs is not None

    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        generated_text = output.outputs[0].text
        assert generated_text is not None

        # Search for function call pattern in the response
        function_call_pattern = r'<function=get_weather>(.*?)</function>'
        matches = re.findall(function_call_pattern, generated_text)

        # Model may legitimately answer without calling the function;
        # only warn in that case rather than failing the test.
        if not matches:
            print(f"Warning: No function calls found in response: "
                  f"{generated_text!r}")
            continue

        # Take the first function call if multiple are found
        json_str = matches[0]
        try:
            json_content = json.loads(json_str)
            assert "city" in json_content
            assert isinstance(json_content["city"], str)
            print(f"Found valid function call: {generated_text!r}")
        except (json.JSONDecodeError, AssertionError) as e:
            pytest.fail("Invalid function call format: "
                        f"{generated_text!r}\nError: {str(e)}")
2025-03-29 00:10:45 -04:00
2025-05-14 18:45:24 -04:00
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize(
    "model_name, guided_decoding_backend, tokenizer_mode, reasoning_parser, speculative_config",  # noqa: E501
    [
        ("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "xgrammar", "auto",
         "deepseek_r1", NGRAM_SPEC_CONFIG),
        ("Qwen/Qwen3-1.7B", "xgrammar", "auto", "deepseek_r1", None),
    ],
)
def test_structured_output_with_reasoning_matrices(
    monkeypatch: pytest.MonkeyPatch,
    guided_decoding_backend: str,
    tokenizer_mode: TokenizerMode,
    reasoning_parser: str,
    model_name: str,
    speculative_config: dict[str, Any] | None,
):
    """Check that guided JSON decoding composes with reasoning models.

    The model is allowed to emit free-form reasoning (e.g. a <think> block)
    before the structured answer; the reasoning parser splits the output and
    only the content part must validate against the JSON schema.
    """
    monkeypatch.setenv("VLLM_USE_V1", "1")

    if current_platform.is_tpu() and speculative_config:
        pytest.skip("TPU does not support speculative decoding")

    # Use a single LLM instance for several scenarios to
    # speed up the test suite.
    llm = LLM(
        model=model_name,
        # Don't use eager execution on TPUs because we want to test for no
        # recompilation at runtime
        enforce_eager=bool(not current_platform.is_tpu()),
        max_model_len=1024,
        max_num_seqs=16,
        guided_decoding_backend=guided_decoding_backend,
        guided_decoding_disable_any_whitespace=True,
        tokenizer_mode=tokenizer_mode,
        reasoning_parser=reasoning_parser,
        speculative_config=speculative_config,
    )
    tokenizer = llm.get_tokenizer(None)
    # Instantiate the registered reasoning parser for this tokenizer.
    reasoner = ReasoningParserManager.get_reasoning_parser(reasoning_parser)(
        tokenizer=tokenizer)

    reasoning_prompt = "Solve the following math problem step-by-step, then provide the final answer as JSON object with a single key 'result'. Make sure to correct your reasoning if there are any issue should it arise.\nProblem: What is 5 * 8 + 2?"  # noqa: E501
    reasoning_schema = {
        "type": "object",
        "properties": {
            "result": {
                "type": "integer"
            }
        },
        "required": ["result"],
        "additionalProperties": False
    }
    if "Qwen3" in model_name:
        # Qwen3 expects an explicit opening think tag to start reasoning.
        reasoning_prompt += "<think>\n"

    sampling_params = SamplingParams(
        temperature=0.1,
        max_tokens=8192,
        guided_decoding=GuidedDecodingParams(json=reasoning_schema),
    )
    outputs = llm.generate(
        [reasoning_prompt],
        sampling_params=sampling_params,
        use_tqdm=True,
    )

    assert outputs is not None
    output = outputs[0]
    assert output is not None and isinstance(output, RequestOutput)
    prompt = output.prompt
    generated_text = output.outputs[0].text
    # Split the raw completion into (reasoning, structured content).
    reasoning_content, content = run_reasoning_extraction(
        reasoner, [generated_text])
    print(
        f"Prompt: {prompt!r}\nReasoning: {reasoning_content!r}\nContent: {content!r}"
    )

    assert content is not None and reasoning_content is not None

    # Only the post-reasoning content must satisfy the schema.
    output_json = json.loads(content)
    jsonschema.validate(instance=output_json, schema=reasoning_schema)
2025-03-29 00:10:45 -04:00
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("model_name, tokenizer_mode",
                         PARAMS_MODELS_TOKENIZER_MODE)
def test_structured_output_auto_mode(
    monkeypatch: pytest.MonkeyPatch,
    unsupported_json_schema: dict[str, Any],
    model_name: str,
    tokenizer_mode: str,
):
    """Verify the "auto" backend falls back when xgrammar can't handle a schema.

    Uses a schema known to be unsupported by xgrammar (fixture defined
    elsewhere, presumably conftest.py) and checks that generation still
    succeeds, and that the same SamplingParams object is reusable afterwards.
    """
    monkeypatch.setenv("VLLM_USE_V1", "1")

    llm = LLM(model=model_name,
              max_model_len=1024,
              guided_decoding_backend="auto",
              tokenizer_mode=tokenizer_mode)

    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))

    prompts = (
        "Give an example JSON object for a grade "
        "that fits this schema: "
        f"{unsupported_json_schema}. Make the response as short as possible.")
    # This would fail with the default of "xgrammar", but in "auto"
    # we will handle fallback automatically.
    outputs = llm.generate(prompts=prompts,
                           sampling_params=sampling_params,
                           use_tqdm=True)
    # Make sure `auto` backend handling doesn't mess up sampling_params
    # and that we can reuse it without error.
    outputs.extend(
        llm.generate(prompts=prompts,
                     sampling_params=sampling_params,
                     use_tqdm=True))

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        generated_text = output.outputs[0].text
        assert generated_text is not None
        print(generated_text)

        # Parse to verify it is valid JSON
        parsed_json = json.loads(generated_text)
        assert isinstance(parsed_json, dict)
2025-04-23 12:34:41 -06:00
@pytest.mark.skip_global_cleanup
def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
    """Verify the guidance backend's disable_additional_properties option.

    The prompt explicitly asks for 20 key-value pairs, but the schema only
    declares a1..a3; with additional properties disabled the output must
    contain exactly the declared keys and nothing more.
    """
    monkeypatch.setenv("VLLM_USE_V1", "1")

    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
              max_model_len=1024,
              guided_decoding_backend="guidance",
              guided_decoding_disable_any_whitespace=True,
              guided_decoding_disable_additional_properties=True)

    # Schema deliberately omits "additionalProperties": False — the
    # restriction must come from the backend option instead.
    schema = {
        'type': 'object',
        'properties': {
            'a1': {
                'type': 'string'
            },
            'a2': {
                'type': 'string'
            },
            'a3': {
                'type': 'string'
            }
        },
        'required': ['a1', 'a2', 'a3'],
    }

    # Hand-rolled Qwen chat template (system + user turns).
    prompt = (
        "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a "
        "helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a "
        "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20. "
        "Make the response as short as possible."
        "<|im_end|>\n<|im_start|>assistant\n")

    def generate_with_backend(backend):
        # Helper: run one guided generation and return the parsed JSON dict.
        guided_params = GuidedDecodingParams(
            json=schema,
            backend=backend,
            disable_any_whitespace=True,
            disable_additional_properties=True)
        sampling_params = SamplingParams(temperature=0,
                                         max_tokens=256,
                                         guided_decoding=guided_params)

        outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
        assert outputs is not None
        generated_text = outputs[0].outputs[0].text
        assert generated_text is not None
        parsed_json = json.loads(generated_text)
        assert isinstance(parsed_json, dict)
        jsonschema.validate(instance=parsed_json, schema=schema)
        return parsed_json

    generated = generate_with_backend("guidance")
    # Declared keys must be present...
    assert "a1" in generated
    assert "a2" in generated
    assert "a3" in generated
    # ...and the extra keys the prompt asked for must have been suppressed.
    assert "a4" not in generated
    assert "a5" not in generated
    assert "a6" not in generated