[Bugfix] Backend option to disable xgrammar any_whitespace (#12744)
Signed-off-by: Wallas Santos <wallashss@ibm.com> Signed-off-by: Joe Runde <Joseph.Runde@ibm.com> Co-authored-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
@@ -6,6 +6,7 @@ import weakref
|
||||
|
||||
import jsonschema
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.entrypoints.llm import LLM
|
||||
@@ -322,3 +323,56 @@ def test_guided_json_object(llm, guided_decoding_backend: str):
|
||||
# Parse to verify it is valid JSON
|
||||
parsed_json = json.loads(generated_text)
|
||||
assert isinstance(parsed_json, dict)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_json_with_any_whitespace_disabled(llm):
|
||||
|
||||
class ResponseSchema(BaseModel):
|
||||
clarifying_question: str
|
||||
cost_per_serving: str
|
||||
calories: str
|
||||
type_dish_ids: str
|
||||
type_meal_ids: str
|
||||
product_ids: list[str]
|
||||
exclude_product_ids: list[str]
|
||||
allergen_ids: list[str]
|
||||
total_cooking_time: str
|
||||
kitchen_ids: str
|
||||
holiday_ids: str
|
||||
|
||||
# Note: Without this setting, the response is sometimes full of `\n`
|
||||
# for some models. This option prevents that.
|
||||
guided_decoding_backend = 'xgrammar:disable-any-whitespace'
|
||||
|
||||
schema = ResponseSchema.model_json_schema()
|
||||
guided_params = GuidedDecodingParams(json=schema,
|
||||
backend=\
|
||||
guided_decoding_backend)
|
||||
sampling_params = SamplingParams(max_tokens=2000,
|
||||
frequency_penalty=0,
|
||||
presence_penalty=-1.1,
|
||||
repetition_penalty=1.3,
|
||||
guided_decoding=guided_params)
|
||||
|
||||
prompt = ("<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You"
|
||||
"are a helpful assistant.<|im_end|>\n<|im_start|>user\nI want a "
|
||||
"quick launch fast with $10.<|im_end|>\n<|im_start|>assistant\n")
|
||||
outputs = llm.generate(prompts=prompt,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
|
||||
assert outputs is not None
|
||||
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
|
||||
generated_text = output.outputs[0].text
|
||||
assert generated_text is not None
|
||||
assert "\n" not in generated_text
|
||||
|
||||
# Parse to verify it is valid JSON
|
||||
parsed_json = json.loads(generated_text)
|
||||
assert isinstance(parsed_json, dict)
|
||||
jsonschema.validate(instance=parsed_json, schema=schema)
|
||||
|
||||
Reference in New Issue
Block a user