[Misc] add fixture to guided processor tests (#6341)
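This change moves the guided-decoding test inputs (the employee-profile JSON schema, the IPv4 regex, and the programming-language choice list) out of module-level constants and into shared pytest fixtures: sample_json_schema, sample_regex, and sample_guided_choice. The fixture definitions themselves are not part of this diff; the sketch below is one plausible shape for them, assuming they live in a shared conftest.py and return exactly the values of the removed constants:

    import pytest


    @pytest.fixture
    def sample_json_schema():
        # Employee-profile schema formerly held in the module-level TEST_SCHEMA.
        return {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"},
                "skills": {
                    "type": "array",
                    "items": {"type": "string", "maxLength": 10},
                    "minItems": 3
                },
                "work history": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "company": {"type": "string"},
                            "duration": {"type": "string"},
                            "position": {"type": "string"}
                        },
                        "required": ["company", "position"]
                    }
                }
            },
            "required": ["name", "age", "skills", "work history"]
        }


    @pytest.fixture
    def sample_regex():
        # IPv4-address regex formerly held in TEST_REGEX.
        return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
                r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")


    @pytest.fixture
    def sample_guided_choice():
        # Choice list formerly held in TEST_CHOICE.
        return [
            "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript",
            "Ruby", "Swift", "Kotlin"
        ]

Fixtures keep a single copy of the sample inputs that multiple test modules can share, rather than each module redefining its own constants.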
@@ -22,53 +22,6 @@ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
 # generation quality here
 LORA_NAME = "typeof/zephyr-7b-beta-lora"
 
-TEST_SCHEMA = {
-    "type": "object",
-    "properties": {
-        "name": {
-            "type": "string"
-        },
-        "age": {
-            "type": "integer"
-        },
-        "skills": {
-            "type": "array",
-            "items": {
-                "type": "string",
-                "maxLength": 10
-            },
-            "minItems": 3
-        },
-        "work history": {
-            "type": "array",
-            "items": {
-                "type": "object",
-                "properties": {
-                    "company": {
-                        "type": "string"
-                    },
-                    "duration": {
-                        "type": "string"
-                    },
-                    "position": {
-                        "type": "string"
-                    }
-                },
-                "required": ["company", "position"]
-            }
-        }
-    },
-    "required": ["name", "age", "skills", "work history"]
-}
-
-TEST_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
-              r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
-
-TEST_CHOICE = [
-    "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", "Ruby",
-    "Swift", "Kotlin"
-]
-
 
 @pytest.fixture(scope="module")
 def zephyr_lora_files():
@@ -408,7 +361,8 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_guided_choice_chat(client: openai.AsyncOpenAI,
-                                  guided_decoding_backend: str):
+                                  guided_decoding_backend: str,
+                                  sample_guided_choice):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -422,10 +376,10 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI,
         model=MODEL_NAME,
         messages=messages,
         max_tokens=10,
-        extra_body=dict(guided_choice=TEST_CHOICE,
+        extra_body=dict(guided_choice=sample_guided_choice,
                         guided_decoding_backend=guided_decoding_backend))
     choice1 = chat_completion.choices[0].message.content
-    assert choice1 in TEST_CHOICE
+    assert choice1 in sample_guided_choice
 
     messages.append({"role": "assistant", "content": choice1})
     messages.append({
@@ -436,10 +390,10 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI,
         model=MODEL_NAME,
         messages=messages,
         max_tokens=10,
-        extra_body=dict(guided_choice=TEST_CHOICE,
+        extra_body=dict(guided_choice=sample_guided_choice,
                         guided_decoding_backend=guided_decoding_backend))
     choice2 = chat_completion.choices[0].message.content
-    assert choice2 in TEST_CHOICE
+    assert choice2 in sample_guided_choice
     assert choice1 != choice2
 
 
@@ -447,7 +401,8 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI,
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_guided_json_chat(client: openai.AsyncOpenAI,
-                                guided_decoding_backend: str):
+                                guided_decoding_backend: str,
+                                sample_json_schema):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -456,18 +411,18 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI,
         "user",
         "content":
         f"Give an example JSON for an employee profile that "
-        f"fits this schema: {TEST_SCHEMA}"
+        f"fits this schema: {sample_json_schema}"
     }]
     chat_completion = await client.chat.completions.create(
         model=MODEL_NAME,
         messages=messages,
         max_tokens=1000,
-        extra_body=dict(guided_json=TEST_SCHEMA,
+        extra_body=dict(guided_json=sample_json_schema,
                         guided_decoding_backend=guided_decoding_backend))
     message = chat_completion.choices[0].message
     assert message.content is not None
     json1 = json.loads(message.content)
-    jsonschema.validate(instance=json1, schema=TEST_SCHEMA)
+    jsonschema.validate(instance=json1, schema=sample_json_schema)
 
     messages.append({"role": "assistant", "content": message.content})
     messages.append({
@@ -480,12 +435,12 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI,
         model=MODEL_NAME,
         messages=messages,
         max_tokens=1000,
-        extra_body=dict(guided_json=TEST_SCHEMA,
+        extra_body=dict(guided_json=sample_json_schema,
                         guided_decoding_backend=guided_decoding_backend))
     message = chat_completion.choices[0].message
     assert message.content is not None
     json2 = json.loads(message.content)
-    jsonschema.validate(instance=json2, schema=TEST_SCHEMA)
+    jsonschema.validate(instance=json2, schema=sample_json_schema)
     assert json1["name"] != json2["name"]
     assert json1["age"] != json2["age"]
 
@@ -494,7 +449,7 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI,
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_guided_regex_chat(client: openai.AsyncOpenAI,
-                                 guided_decoding_backend: str):
+                                 guided_decoding_backend: str, sample_regex):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -502,17 +457,17 @@ async def test_guided_regex_chat(client: openai.AsyncOpenAI,
         "role":
         "user",
         "content":
-        f"Give an example IP address with this regex: {TEST_REGEX}"
+        f"Give an example IP address with this regex: {sample_regex}"
     }]
     chat_completion = await client.chat.completions.create(
         model=MODEL_NAME,
         messages=messages,
         max_tokens=20,
-        extra_body=dict(guided_regex=TEST_REGEX,
+        extra_body=dict(guided_regex=sample_regex,
                         guided_decoding_backend=guided_decoding_backend))
     ip1 = chat_completion.choices[0].message.content
     assert ip1 is not None
-    assert re.fullmatch(TEST_REGEX, ip1) is not None
+    assert re.fullmatch(sample_regex, ip1) is not None
 
     messages.append({"role": "assistant", "content": ip1})
     messages.append({"role": "user", "content": "Give me a different one"})
@@ -520,11 +475,11 @@ async def test_guided_regex_chat(client: openai.AsyncOpenAI,
         model=MODEL_NAME,
         messages=messages,
         max_tokens=20,
-        extra_body=dict(guided_regex=TEST_REGEX,
+        extra_body=dict(guided_regex=sample_regex,
                         guided_decoding_backend=guided_decoding_backend))
     ip2 = chat_completion.choices[0].message.content
     assert ip2 is not None
-    assert re.fullmatch(TEST_REGEX, ip2) is not None
+    assert re.fullmatch(sample_regex, ip2) is not None
     assert ip1 != ip2
 
 
@@ -553,7 +508,8 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI):
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
-                                           guided_decoding_backend: str):
+                                           guided_decoding_backend: str,
+                                           sample_guided_choice):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -569,7 +525,7 @@ async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
         max_tokens=10,
         logprobs=True,
         top_logprobs=5,
-        extra_body=dict(guided_choice=TEST_CHOICE,
+        extra_body=dict(guided_choice=sample_guided_choice,
                         guided_decoding_backend=guided_decoding_backend))
 
     assert chat_completion.choices[0].logprobs is not None
@@ -585,7 +541,8 @@ async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_named_tool_use(client: openai.AsyncOpenAI,
-                              guided_decoding_backend: str):
+                              guided_decoding_backend: str,
+                              sample_json_schema):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -594,7 +551,7 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
         "user",
         "content":
         f"Give an example JSON for an employee profile that "
-        f"fits this schema: {TEST_SCHEMA}"
+        f"fits this schema: {sample_json_schema}"
     }]
 
     # non-streaming
@@ -608,7 +565,7 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
             "function": {
                 "name": "dummy_function_name",
                 "description": "This is a dummy function",
-                "parameters": TEST_SCHEMA
+                "parameters": sample_json_schema
             }
         }],
         tool_choice={
@@ -621,7 +578,7 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
     assert len(message.content) == 0
     json_string = message.tool_calls[0].function.arguments
     json1 = json.loads(json_string)
-    jsonschema.validate(instance=json1, schema=TEST_SCHEMA)
+    jsonschema.validate(instance=json1, schema=sample_json_schema)
 
     messages.append({"role": "assistant", "content": json_string})
     messages.append({
@@ -642,7 +599,7 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
             "function": {
                 "name": "dummy_function_name",
                 "description": "This is a dummy function",
-                "parameters": TEST_SCHEMA
+                "parameters": sample_json_schema
             }
         }],
         tool_choice={
@@ -667,7 +624,7 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
     # finish reason should only return in last block
     assert finish_reason_count == 1
     json2 = json.loads("".join(output))
-    jsonschema.validate(instance=json2, schema=TEST_SCHEMA)
+    jsonschema.validate(instance=json2, schema=sample_json_schema)
     assert json1["name"] != json2["name"]
     assert json1["age"] != json2["age"]
 
@@ -675,7 +632,8 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
 @pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
 async def test_required_tool_use_not_yet_supported(
-        client: openai.AsyncOpenAI, guided_decoding_backend: str):
+        client: openai.AsyncOpenAI, guided_decoding_backend: str,
+        sample_json_schema):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -684,7 +642,7 @@ async def test_required_tool_use_not_yet_supported(
         "user",
         "content":
         f"Give an example JSON for an employee profile that "
-        f"fits this schema: {TEST_SCHEMA}"
+        f"fits this schema: {sample_json_schema}"
     }]
 
     with pytest.raises(openai.BadRequestError):
@@ -697,7 +655,7 @@ async def test_required_tool_use_not_yet_supported(
                 "function": {
                     "name": "dummy_function_name",
                     "description": "This is a dummy function",
-                    "parameters": TEST_SCHEMA
+                    "parameters": sample_json_schema
                 }
             }],
            tool_choice="required")
@@ -712,7 +670,7 @@ async def test_required_tool_use_not_yet_supported(
                 "function": {
                     "name": "dummy_function_name",
                     "description": "This is a dummy function",
-                    "parameters": TEST_SCHEMA
+                    "parameters": sample_json_schema
                 }
             }],
            tool_choice="auto")
@@ -720,8 +678,9 @@ async def test_required_tool_use_not_yet_supported(
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
-async def test_inconsistent_tool_choice_and_tools(
-        client: openai.AsyncOpenAI, guided_decoding_backend: str):
+async def test_inconsistent_tool_choice_and_tools(client: openai.AsyncOpenAI,
+                                                  guided_decoding_backend: str,
+                                                  sample_json_schema):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -730,7 +689,7 @@ async def test_inconsistent_tool_choice_and_tools(
         "user",
         "content":
         f"Give an example JSON for an employee profile that "
-        f"fits this schema: {TEST_SCHEMA}"
+        f"fits this schema: {sample_json_schema}"
     }]
 
     with pytest.raises(openai.BadRequestError):
@@ -755,7 +714,7 @@ async def test_inconsistent_tool_choice_and_tools(
                 "function": {
                     "name": "dummy_function_name",
                     "description": "This is a dummy function",
-                    "parameters": TEST_SCHEMA
+                    "parameters": sample_json_schema
                 }
             }],
             tool_choice={
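With the constants gone, pytest injects each sample_* value by matching test parameter names against fixture names at collection time. A minimal, hypothetical illustration of that injection (not part of this diff):

    import re


    def test_regex_fixture_injection(sample_regex):
        # pytest resolves `sample_regex` to the fixture of the same name.
        assert re.fullmatch(sample_regex, "127.0.0.1") is not None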