[Bugfix] Backend option to disable xgrammar any_whitespace (#12744)

Signed-off-by: Wallas Santos <wallashss@ibm.com>
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
Wallas Henrique
2025-02-26 15:52:34 -03:00
committed by GitHub
parent d08b285adf
commit 4cb6fa0a9c
3 changed files with 88 additions and 3 deletions

View File

@@ -6,6 +6,7 @@ import weakref
import jsonschema
import pytest
from pydantic import BaseModel
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.entrypoints.llm import LLM
@@ -322,3 +323,56 @@ def test_guided_json_object(llm, guided_decoding_backend: str):
# Parse to verify it is valid JSON
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)
@pytest.mark.skip_global_cleanup
def test_json_with_any_whitespace_disabled(llm):
class ResponseSchema(BaseModel):
clarifying_question: str
cost_per_serving: str
calories: str
type_dish_ids: str
type_meal_ids: str
product_ids: list[str]
exclude_product_ids: list[str]
allergen_ids: list[str]
total_cooking_time: str
kitchen_ids: str
holiday_ids: str
# Note: Without this setting, the response is sometimes full of `\n`
# for some models. This option prevents that.
guided_decoding_backend = 'xgrammar:disable-any-whitespace'
schema = ResponseSchema.model_json_schema()
guided_params = GuidedDecodingParams(json=schema,
backend=\
guided_decoding_backend)
sampling_params = SamplingParams(max_tokens=2000,
frequency_penalty=0,
presence_penalty=-1.1,
repetition_penalty=1.3,
guided_decoding=guided_params)
prompt = ("<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You"
"are a helpful assistant.<|im_end|>\n<|im_start|>user\nI want a "
"quick launch fast with $10.<|im_end|>\n<|im_start|>assistant\n")
outputs = llm.generate(prompts=prompt,
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
generated_text = output.outputs[0].text
assert generated_text is not None
assert "\n" not in generated_text
# Parse to verify it is valid JSON
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)
jsonschema.validate(instance=parsed_json, schema=schema)