[CI] Make JSON output tests less likely to fail (#17859)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
@@ -62,6 +62,16 @@ class CarDescription(BaseModel):
|
||||
car_type: CarType
|
||||
|
||||
|
||||
def _load_json(s: str, backend: str) -> str:
|
||||
if backend != "xgrammar":
|
||||
return json.loads(s)
|
||||
|
||||
# xgrammar specific workarounds
|
||||
# https://github.com/mlc-ai/xgrammar/issues/286
|
||||
s = re.sub(r'[\x00-\x1F\x7F-\xFF]', '', s)
|
||||
return json.loads(s)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize(
|
||||
"model_name, guided_decoding_backend, tokenizer_mode, speculative_config",
|
||||
@@ -102,7 +112,7 @@ def test_structured_output(
|
||||
#
|
||||
sampling_params = SamplingParams(
|
||||
temperature=1.0,
|
||||
max_tokens=1000,
|
||||
max_tokens=4096,
|
||||
guided_decoding=GuidedDecodingParams(json=sample_json_schema))
|
||||
outputs = llm.generate(prompts=[
|
||||
(f"Give an example JSON for an employee profile that fits this "
|
||||
@@ -131,7 +141,7 @@ def test_structured_output(
|
||||
#
|
||||
sampling_params = SamplingParams(
|
||||
temperature=1.0,
|
||||
max_tokens=100,
|
||||
max_tokens=4096,
|
||||
n=2,
|
||||
guided_decoding=GuidedDecodingParams(json_object=True))
|
||||
|
||||
@@ -161,7 +171,7 @@ def test_structured_output(
|
||||
#
|
||||
sampling_params = SamplingParams(
|
||||
temperature=1.0,
|
||||
max_tokens=1000,
|
||||
max_tokens=4096,
|
||||
guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
|
||||
if guided_decoding_backend.startswith("xgrammar"):
|
||||
with pytest.raises(ValueError,
|
||||
@@ -376,12 +386,13 @@ def test_structured_output(
|
||||
"minLength": min_length
|
||||
}
|
||||
},
|
||||
"required": ["description"]
|
||||
"required": ["description"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=1.0,
|
||||
max_tokens=1000,
|
||||
max_tokens=4096,
|
||||
guided_decoding=GuidedDecodingParams(json=json_schema))
|
||||
|
||||
outputs = llm.generate(
|
||||
@@ -417,7 +428,8 @@ def test_structured_output(
|
||||
"city": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": False
|
||||
},
|
||||
"end": "</function>"
|
||||
}],
|
||||
@@ -426,7 +438,7 @@ def test_structured_output(
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.0,
|
||||
max_tokens=100,
|
||||
max_tokens=4096,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
structural_tag=json.dumps(structural_tag_config)))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user