[V1] Add structural_tag support using xgrammar (#17085)
This commit is contained in:
@@ -350,6 +350,7 @@ def test_structured_output(
|
||||
temperature=1.0,
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(json=json_schema))
|
||||
|
||||
outputs = llm.generate(
|
||||
prompts="Generate a description of a frog using 50 characters.",
|
||||
sampling_params=sampling_params,
|
||||
@@ -368,6 +369,106 @@ def test_structured_output(
|
||||
output_json = json.loads(generated_text)
|
||||
jsonschema.validate(instance=output_json, schema=json_schema)
|
||||
|
||||
#
|
||||
# Test 11: Generate structured output using structural_tag format
|
||||
#
|
||||
structural_tag_config = {
|
||||
"type":
|
||||
"structural_tag",
|
||||
"structures": [{
|
||||
"begin": "<function=get_weather>",
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"end": "</function>"
|
||||
}],
|
||||
"triggers": ["<function="]
|
||||
}
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.0,
|
||||
max_tokens=100,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
structural_tag=json.dumps(structural_tag_config)))
|
||||
|
||||
prompt = """
|
||||
You have access to the following function to retrieve the weather in a city:
|
||||
|
||||
{
|
||||
"name": "get_weather",
|
||||
"parameters": {
|
||||
"city": {
|
||||
"param_type": "string",
|
||||
"description": "The city to get the weather for",
|
||||
"required": True
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
If a you choose to call a function ONLY reply in the following format:
|
||||
<{start_tag}={function_name}>{parameters}{end_tag}
|
||||
where
|
||||
|
||||
start_tag => `<function`
|
||||
parameters => a JSON dict with the function argument name
|
||||
as key and function argument value as value.
|
||||
end_tag => `</function>`
|
||||
|
||||
Here is an example,
|
||||
<function=example_function_name>{"example_name": "example_value"}</function>
|
||||
|
||||
Reminder:
|
||||
- Function calls MUST follow the specified format
|
||||
- Required parameters MUST be specified
|
||||
- Only call one function at a time
|
||||
- Put the entire function call reply on one line
|
||||
- Always add your sources when using search results to answer the user query
|
||||
|
||||
You are a helpful assistant.
|
||||
|
||||
Given the previous instructions, what is the weather in New York City?
|
||||
"""
|
||||
|
||||
# Change this once other backends support structural_tag
|
||||
if guided_decoding_backend.startswith("xgrammar"):
|
||||
outputs = llm.generate(prompts=prompt,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
assert outputs is not None
|
||||
else:
|
||||
outputs = []
|
||||
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
generated_text = output.outputs[0].text
|
||||
assert generated_text is not None
|
||||
|
||||
# Search for function call pattern in the response
|
||||
function_call_pattern = r'<function=get_weather>(.*?)</function>'
|
||||
matches = re.findall(function_call_pattern, generated_text)
|
||||
|
||||
if not matches:
|
||||
print(f"Warning: No function calls found in response: "
|
||||
f"{generated_text!r}")
|
||||
continue
|
||||
|
||||
# Take the first function call if multiple are found
|
||||
json_str = matches[0]
|
||||
try:
|
||||
json_content = json.loads(json_str)
|
||||
assert "city" in json_content
|
||||
assert isinstance(json_content["city"], str)
|
||||
print(f"Found valid function call: {generated_text!r}")
|
||||
except (json.JSONDecodeError, AssertionError) as e:
|
||||
pytest.fail("Invalid function call format: "
|
||||
f"{generated_text!r}\nError: {str(e)}")
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("model_name, tokenizer_mode",
|
||||
|
||||
Reference in New Issue
Block a user