Support arbitrary json_object in OpenAI and Context Free Grammar (#3211)

This commit is contained in:
Simon Mo
2024-03-16 13:35:27 -07:00
committed by GitHub
parent 8e67598aa6
commit 120157fd2a
4 changed files with 177 additions and 50 deletions

View File

@@ -660,5 +660,55 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI):
extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))
async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
resp = await client.chat.completions.create(
model=MODEL_NAME,
messages=[{
"role":
"user",
"content": ('what is 1+1? please respond with a JSON object, '
'the format is {"result": 2}')
}],
response_format={"type": "json_object"})
content = resp.choices[0].message.content
loaded = json.loads(content)
assert loaded == {"result": 2}, loaded
async def test_guided_grammar(server, client: openai.AsyncOpenAI):
simple_sql_grammar = """
start: select_statement
select_statement: "SELECT" column "from" table "where" condition
column: "col_1" | "col_2"
table: "table_1" | "table_2"
condition: column "=" number
number: "1" | "2"
"""
completion = await client.completions.create(
model=MODEL_NAME,
prompt=("Generate a sql state that select col_1 from "
"table_1 where it is equals to 1"),
temperature=1.0,
max_tokens=500,
extra_body=dict(guided_grammar=simple_sql_grammar))
content = completion.choices[0].text
# use Lark to parse the output, and make sure it's a valid parse tree
from lark import Lark
parser = Lark(simple_sql_grammar)
parser.parse(content)
# remove spaces for comparison b/c we removed them in the grammar
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "")
assert content.strip() == ground_truth
if __name__ == "__main__":
pytest.main([__file__])