[V1] Support disable_any_whtespace for guidance backend (#15584)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
Russell Bryant
2025-03-28 11:46:45 -04:00
committed by GitHub
parent 541d1df486
commit 7329ff5468
6 changed files with 44 additions and 117 deletions

View File

@@ -15,7 +15,9 @@ from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
GUIDED_DECODING_BACKENDS_V1 = ["xgrammar", "guidance"]
GUIDED_DECODING_BACKENDS_V1 = [
"xgrammar:disable-any-whitespace", "guidance:disable-any-whitespace"
]
MODELS_TO_TEST = [
"Qwen/Qwen2.5-1.5B-Instruct", "mistralai/Ministral-8B-Instruct-2410"
]
@@ -55,50 +57,8 @@ def test_guided_json_completion(
generated_text = output.outputs[0].text
assert generated_text is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=sample_json_schema)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_json_completion_disable_any_whitespace(
monkeypatch: pytest.MonkeyPatch,
sample_json_schema: dict[str, Any],
guided_decoding_backend: str,
model_name: str,
):
if guided_decoding_backend != "xgrammar":
pytest.skip("disable-any-whitespace is only supported for xgrammar.")
guided_decoding_backend = 'xgrammar:disable-any-whitespace'
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(json=sample_json_schema))
outputs = llm.generate(prompts=[
f"Give an example JSON for an employee profile "
f"that fits this schema: {sample_json_schema}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
assert "\n" not in generated_text
if 'disable-any-whitespace' in guided_decoding_backend:
assert "\n" not in generated_text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=sample_json_schema)
@@ -142,7 +102,7 @@ def test_guided_json_object(
# Parse to verify it is valid JSON
parsed_json = json.loads(generated_text)
allowed_types: tuple[type, ...] = (dict, )
if guided_decoding_backend == "xgrammar":
if guided_decoding_backend.startswith("xgrammar"):
# TODO - we are currently too permissive with xgrammar and
# allow # any valid json (typically comes back as a list or
# object). We can fix this by specifying a jsonschema of
@@ -170,7 +130,7 @@ def test_guided_json_unsupported_schema(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
if guided_decoding_backend == "xgrammar":
if guided_decoding_backend.startswith("xgrammar"):
with pytest.raises(ValueError,
match="The provided JSON schema contains features "
"not supported by xgrammar."):