[V1] Support disable_any_whtespace for guidance backend (#15584)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
@@ -15,7 +15,9 @@ from vllm.entrypoints.llm import LLM
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
|
||||
|
||||
GUIDED_DECODING_BACKENDS_V1 = ["xgrammar", "guidance"]
|
||||
GUIDED_DECODING_BACKENDS_V1 = [
|
||||
"xgrammar:disable-any-whitespace", "guidance:disable-any-whitespace"
|
||||
]
|
||||
MODELS_TO_TEST = [
|
||||
"Qwen/Qwen2.5-1.5B-Instruct", "mistralai/Ministral-8B-Instruct-2410"
|
||||
]
|
||||
@@ -55,50 +57,8 @@ def test_guided_json_completion(
|
||||
|
||||
generated_text = output.outputs[0].text
|
||||
assert generated_text is not None
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
output_json = json.loads(generated_text)
|
||||
jsonschema.validate(instance=output_json, schema=sample_json_schema)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
GUIDED_DECODING_BACKENDS_V1)
|
||||
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
|
||||
def test_guided_json_completion_disable_any_whitespace(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
sample_json_schema: dict[str, Any],
|
||||
guided_decoding_backend: str,
|
||||
model_name: str,
|
||||
):
|
||||
if guided_decoding_backend != "xgrammar":
|
||||
pytest.skip("disable-any-whitespace is only supported for xgrammar.")
|
||||
guided_decoding_backend = 'xgrammar:disable-any-whitespace'
|
||||
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
llm = LLM(model=model_name,
|
||||
max_model_len=1024,
|
||||
guided_decoding_backend=guided_decoding_backend)
|
||||
sampling_params = SamplingParams(
|
||||
temperature=1.0,
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(json=sample_json_schema))
|
||||
outputs = llm.generate(prompts=[
|
||||
f"Give an example JSON for an employee profile "
|
||||
f"that fits this schema: {sample_json_schema}"
|
||||
] * 2,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
|
||||
assert outputs is not None
|
||||
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
prompt = output.prompt
|
||||
|
||||
generated_text = output.outputs[0].text
|
||||
assert generated_text is not None
|
||||
assert "\n" not in generated_text
|
||||
if 'disable-any-whitespace' in guided_decoding_backend:
|
||||
assert "\n" not in generated_text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
output_json = json.loads(generated_text)
|
||||
jsonschema.validate(instance=output_json, schema=sample_json_schema)
|
||||
@@ -142,7 +102,7 @@ def test_guided_json_object(
|
||||
# Parse to verify it is valid JSON
|
||||
parsed_json = json.loads(generated_text)
|
||||
allowed_types: tuple[type, ...] = (dict, )
|
||||
if guided_decoding_backend == "xgrammar":
|
||||
if guided_decoding_backend.startswith("xgrammar"):
|
||||
# TODO - we are currently too permissive with xgrammar and
|
||||
# allow # any valid json (typically comes back as a list or
|
||||
# object). We can fix this by specifying a jsonschema of
|
||||
@@ -170,7 +130,7 @@ def test_guided_json_unsupported_schema(
|
||||
temperature=1.0,
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
|
||||
if guided_decoding_backend == "xgrammar":
|
||||
if guided_decoding_backend.startswith("xgrammar"):
|
||||
with pytest.raises(ValueError,
|
||||
match="The provided JSON schema contains features "
|
||||
"not supported by xgrammar."):
|
||||
|
||||
Reference in New Issue
Block a user