[V0][V1][Core] Add outlines integration for V1, and update V0 integration. (#15975)
Signed-off-by: Nathan Hoos <thwackyy.y@gmail.com>
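For orientation, a minimal usage sketch (not part of this diff) of how the new "outlines" backend is selected through the same offline LLM API these tests exercise. The model name, prompt, and schema below are illustrative assumptions; LLM, SamplingParams, and GuidedDecodingParams are the vLLM classes used in the test itself.

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Illustrative JSON schema (the test uses its own sample schema fixture).
schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"}
    },
    "required": ["name", "age"]
}

# Select the outlines backend the same way the updated test does.
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
          guided_decoding_backend="outlines",
          max_model_len=1024)

# Schema-constrained decoding, mirroring Test 1 below (which runs for all backends).
sampling_params = SamplingParams(
    temperature=1.0,
    max_tokens=256,
    guided_decoding=GuidedDecodingParams(json=schema))

outputs = llm.generate(
    prompts=("Generate a JSON object with name and age fields for "
             "John Smith who is 31 years old."),
    sampling_params=sampling_params)
print(outputs[0].outputs[0].text)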
@@ -41,6 +41,10 @@ PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
     ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
     ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None),
     ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None),
+    ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None),
+    ("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None),
+    ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto",
+     NGRAM_SPEC_CONFIG),
     #FIXME: This test is flaky on CI thus disabled
     #("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
     ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto",
@@ -106,13 +110,15 @@ def test_structured_output(
     enforce_eager = bool(not current_platform.is_tpu())
     # Use a single LLM instance for several scenarios to
     # speed up the test suite.
-    llm = LLM(model=model_name,
-              enforce_eager=enforce_eager,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend,
-              guided_decoding_disable_any_whitespace=True,
-              tokenizer_mode=tokenizer_mode,
-              speculative_config=speculative_config)
+    llm = LLM(
+        model=model_name,
+        enforce_eager=enforce_eager,
+        max_model_len=1024,
+        guided_decoding_backend=guided_decoding_backend,
+        guided_decoding_disable_any_whitespace=(guided_decoding_backend
+                                                in {"xgrammar", "guidance"}),
+        tokenizer_mode=tokenizer_mode,
+        speculative_config=speculative_config)

     #
     # Test 1: Generate JSON output based on a provided schema
@@ -146,32 +152,33 @@ def test_structured_output(
     #
     # Test 2: Generate JSON object without a schema
     #
-    sampling_params = SamplingParams(
-        temperature=1.0,
-        max_tokens=4096,
-        n=2,
-        guided_decoding=GuidedDecodingParams(json_object=True))
+    if guided_decoding_backend != "outlines":
+        sampling_params = SamplingParams(
+            temperature=1.0,
+            max_tokens=4096,
+            n=2,
+            guided_decoding=GuidedDecodingParams(json_object=True))

-    outputs = llm.generate(
-        prompts=("Generate a JSON object with curly braces for a person with "
-                 "name and age fields for John Smith who is 31 years old. "
-                 "Make the response as short as possible."),
-        sampling_params=sampling_params,
-        use_tqdm=True)
+        outputs = llm.generate(prompts=(
+            "Generate a JSON object with curly braces for a person with "
+            "name and age fields for John Smith who is 31 years old. "
+            "Make the response as short as possible."),
+                               sampling_params=sampling_params,
+                               use_tqdm=True)

-    assert outputs is not None
-    for output in outputs:
-        assert output is not None
-        assert isinstance(output, RequestOutput)
+        assert outputs is not None
+        for output in outputs:
+            assert output is not None
+            assert isinstance(output, RequestOutput)

-        for i in range(2):
-            generated_text = output.outputs[i].text
-            print(generated_text)
-            assert generated_text is not None
+            for i in range(2):
+                generated_text = output.outputs[i].text
+                print(generated_text)
+                assert generated_text is not None

-            # Parse to verify it is a valid JSON object
-            parsed_json = json.loads(generated_text)
-            assert isinstance(parsed_json, dict)
+                # Parse to verify it is a valid JSON object
+                parsed_json = json.loads(generated_text)
+                assert isinstance(parsed_json, dict)

     #
     # Test 3: test a jsonschema incompatible with xgrammar
@@ -210,97 +217,98 @@ def test_structured_output(
         parsed_json = json.loads(generated_text)
         assert isinstance(parsed_json, dict)

-    #
-    # Test 4: Generate SQL statement using EBNF grammar
-    #
-    sampling_params = SamplingParams(
-        temperature=0.8,
-        top_p=0.95,
-        max_tokens=1000,
-        guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
-    outputs = llm.generate(
-        prompts=(
-            "Generate a sql statement that selects col_1 from "
-            "table_1 where it is equal to 1. Make the response as short as "
-            "possible."),
-        sampling_params=sampling_params,
-        use_tqdm=True,
-    )
+    if guided_decoding_backend != "outlines":
+        #
+        # Test 4: Generate SQL statement using EBNF grammar
+        #
+        sampling_params = SamplingParams(
+            temperature=0.8,
+            top_p=0.95,
+            max_tokens=1000,
+            guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
+        outputs = llm.generate(
+            prompts=(
+                "Generate a sql statement that selects col_1 from "
+                "table_1 where it is equal to 1. Make the response as short "
+                "as possible."),
+            sampling_params=sampling_params,
+            use_tqdm=True,
+        )

-    assert outputs is not None
-    for output in outputs:
-        assert output is not None
-        assert isinstance(output, RequestOutput)
-        prompt = output.prompt
+        assert outputs is not None
+        for output in outputs:
+            assert output is not None
+            assert isinstance(output, RequestOutput)
+            prompt = output.prompt

-        generated_text = output.outputs[0].text
-        assert generated_text is not None
+            generated_text = output.outputs[0].text
+            assert generated_text is not None

-        # remove spaces for comparison b/c we removed them in the grammar
-        ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
-            " ", "")
+            # remove spaces for comparison b/c we removed them in the grammar
+            ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
+                " ", "")

-        assert generated_text.strip() == ground_truth
+            assert generated_text.strip() == ground_truth

-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

-    #
-    # Test 5: Generate SQL statement using Lark grammar
-    #
-    sampling_params = SamplingParams(
-        temperature=0.8,
-        top_p=0.95,
-        max_tokens=1000,
-        guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
-    outputs = llm.generate(
-        prompts=(
-            "Generate a sql statement that selects col_1 from "
-            "table_1 where it is equal to 1. Make the response as short as "
-            "possible."),
-        sampling_params=sampling_params,
-        use_tqdm=True,
-    )
+        #
+        # Test 5: Generate SQL statement using Lark grammar
+        #
+        sampling_params = SamplingParams(
+            temperature=0.8,
+            top_p=0.95,
+            max_tokens=1000,
+            guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
+        outputs = llm.generate(
+            prompts=(
+                "Generate a sql statement that selects col_1 from "
+                "table_1 where it is equal to 1. Make the response as short as "
+                "possible."),
+            sampling_params=sampling_params,
+            use_tqdm=True,
+        )

-    assert outputs is not None
-    for output in outputs:
-        assert output is not None
-        assert isinstance(output, RequestOutput)
-        prompt = output.prompt
+        assert outputs is not None
+        for output in outputs:
+            assert output is not None
+            assert isinstance(output, RequestOutput)
+            prompt = output.prompt

-        generated_text = output.outputs[0].text
-        assert generated_text is not None
+            generated_text = output.outputs[0].text
+            assert generated_text is not None

-        # use Lark to parse the output, and make sure it's a valid parse tree
-        from lark import Lark
-        parser = Lark(sample_sql_lark)
-        parser.parse(generated_text)
+            # use Lark to parse the output, and make sure it's a valid parse tree
+            from lark import Lark
+            parser = Lark(sample_sql_lark)
+            parser.parse(generated_text)

-        # remove spaces for comparison b/c we removed them in the grammar
-        ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
-            " ", "")
+            # remove spaces for comparison b/c we removed them in the grammar
+            ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
+                " ", "")

-        assert generated_text.strip() == ground_truth
+            assert generated_text.strip() == ground_truth

-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

-    #
-    # Test 6: Test invalid grammar input
-    #
-    sampling_params = SamplingParams(
-        temperature=0.8,
-        top_p=0.95,
-        max_tokens=1000,
-        guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
-    with pytest.raises(ValueError, match="Failed to convert the grammar "):
-        llm.generate(
-            prompts=
-            ("Generate a sql statement that selects col_1 from "
-             "table_1 where it is equal to 1. Make the response as short "
-             "as possible."),
-            sampling_params=sampling_params,
-            use_tqdm=True,
-        )
+        #
+        # Test 6: Test invalid grammar input
+        #
+        sampling_params = SamplingParams(
+            temperature=0.8,
+            top_p=0.95,
+            max_tokens=1000,
+            guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
+        with pytest.raises(ValueError, match="Failed to convert the grammar "):
+            llm.generate(
+                prompts=
+                ("Generate a sql statement that selects col_1 from "
+                 "table_1 where it is equal to 1. Make the response as short "
+                 "as possible."),
+                sampling_params=sampling_params,
+                use_tqdm=True,
+            )

     #
     # Test 7: Generate text based on a regex pattern
     #
@@ -421,35 +429,36 @@ def test_structured_output(
         output_json = json.loads(generated_text)
         jsonschema.validate(instance=output_json, schema=json_schema)

-    #
-    # Test 11: Generate structured output using structural_tag format
-    #
-    structural_tag_config = {
-        "type":
-        "structural_tag",
-        "structures": [{
-            "begin": "<function=get_weather>",
-            "schema": {
-                "type": "object",
-                "properties": {
-                    "city": {
-                        "type": "string"
-                    }
-                },
-                "additionalProperties": False
-            },
-            "end": "</function>"
-        }],
-        "triggers": ["<function="]
-    }
+    if guided_decoding_backend != "outlines":
+        #
+        # Test 11: Generate structured output using structural_tag format
+        #
+        structural_tag_config = {
+            "type":
+            "structural_tag",
+            "structures": [{
+                "begin": "<function=get_weather>",
+                "schema": {
+                    "type": "object",
+                    "properties": {
+                        "city": {
+                            "type": "string"
+                        }
+                    },
+                    "additionalProperties": False
+                },
+                "end": "</function>"
+            }],
+            "triggers": ["<function="]
+        }

-    sampling_params = SamplingParams(
-        temperature=0.0,
-        max_tokens=4096,
-        guided_decoding=GuidedDecodingParams(
-            structural_tag=json.dumps(structural_tag_config)))
+        sampling_params = SamplingParams(
+            temperature=0.0,
+            max_tokens=4096,
+            guided_decoding=GuidedDecodingParams(
+                structural_tag=json.dumps(structural_tag_config)))

-    prompt = """
+        prompt = """
 You have access to the following function to retrieve the weather in a city:

 {
@@ -469,7 +478,7 @@ where

 start_tag => `<function`
 parameters => a JSON dict with the function argument name
-as key and function argument value as value.
+as key and function argument value as value.
 end_tag => `</function>`

 Here is an example,
@@ -488,37 +497,37 @@ Given the previous instructions, what is the weather in New York City? \
 Make the response as short as possible.
 """

-    # Change this once other backends support structural_tag
-    outputs = llm.generate(prompts=prompt,
-                           sampling_params=sampling_params,
-                           use_tqdm=True)
-    assert outputs is not None
+        # Change this once other backends support structural_tag
+        outputs = llm.generate(prompts=prompt,
+                               sampling_params=sampling_params,
+                               use_tqdm=True)
+        assert outputs is not None

-    for output in outputs:
-        assert output is not None
-        assert isinstance(output, RequestOutput)
-        generated_text = output.outputs[0].text
-        assert generated_text is not None
+        for output in outputs:
+            assert output is not None
+            assert isinstance(output, RequestOutput)
+            generated_text = output.outputs[0].text
+            assert generated_text is not None

-        # Search for function call pattern in the response
-        function_call_pattern = r'<function=get_weather>(.*?)</function>'
-        matches = re.findall(function_call_pattern, generated_text)
+            # Search for function call pattern in the response
+            function_call_pattern = r'<function=get_weather>(.*?)</function>'
+            matches = re.findall(function_call_pattern, generated_text)

-        if not matches:
-            print(f"Warning: No function calls found in response: "
-                  f"{generated_text!r}")
-            continue
+            if not matches:
+                print(f"Warning: No function calls found in response: "
+                      f"{generated_text!r}")
+                continue

-        # Take the first function call if multiple are found
-        json_str = matches[0]
-        try:
-            json_content = json.loads(json_str)
-            assert "city" in json_content
-            assert isinstance(json_content["city"], str)
-            print(f"Found valid function call: {generated_text!r}")
-        except (json.JSONDecodeError, AssertionError) as e:
-            pytest.fail("Invalid function call format: "
-                        f"{generated_text!r}\nError: {str(e)}")
+            # Take the first function call if multiple are found
+            json_str = matches[0]
+            try:
+                json_content = json.loads(json_str)
+                assert "city" in json_content
+                assert isinstance(json_content["city"], str)
+                print(f"Found valid function call: {generated_text!r}")
+            except (json.JSONDecodeError, AssertionError) as e:
+                pytest.fail("Invalid function call format: "
+                            f"{generated_text!r}\nError: {str(e)}")


 @pytest.mark.skip_global_cleanup
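In summary, with these test changes the outlines backend is exercised with both the auto and mistral tokenizer modes and with NGRAM_SPEC_CONFIG, and it appears to run the unguarded scenarios (such as the schema-constrained JSON test and the regex test) unchanged. The free-form json_object test, the EBNF/Lark grammar tests, and the structural_tag test are skipped for outlines via the added `if guided_decoding_backend != "outlines":` guards, and guided_decoding_disable_any_whitespace is now passed only when the backend is xgrammar or guidance.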