[V0 deprecation] Guided decoding (#21347)
Signed-off-by: Reza Barazesh <rezabarazesh@meta.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
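For orientation: the request pattern these removed V0 tests exercise, and which the updated tests below keep driving, is a minimal sketch like the following (not verbatim from the tree; the model name, prompt, and regex are illustrative):

    from vllm.entrypoints.llm import LLM
    from vllm.sampling_params import GuidedDecodingParams, SamplingParams

    # Illustrative model and constraint; any regex-capable backend works here.
    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", max_model_len=1024)
    sampling_params = SamplingParams(
        temperature=0.8,
        guided_decoding=GuidedDecodingParams(regex=r"\d+\.\d+\.\d+\.\d+"))
    outputs = llm.generate(prompts=["Give an example IPv4 address: "],
                           sampling_params=sampling_params)
    # The decoded text is constrained to match the regex.
    print(outputs[0].outputs[0].text)
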
@@ -1,552 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import json
-import weakref
-from enum import Enum
-
-import jsonschema
-import pytest
-import regex as re
-from pydantic import BaseModel
-
-from vllm.distributed import cleanup_dist_env_and_memory
-from vllm.entrypoints.llm import LLM
-from vllm.outputs import RequestOutput
-from vllm.sampling_params import GuidedDecodingParams, SamplingParams
-
-MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
-
-# Separate backends which support grammars vs ones
-# which only support regex based constraints in tests.
-GRAMMAR_DECODING_BACKENDS = [
-    # (backend, disable_any_whitespace),
-    ("lm-format-enforcer", False),
-    ("xgrammar", True),
-    ("guidance", True),
-]
-
-ALL_DECODING_BACKENDS = ([("outlines", False)] + GRAMMAR_DECODING_BACKENDS)
-
-
-@pytest.fixture(scope="module")
-def llm():
-    # pytest caches the fixture so we use weakref.proxy to
-    # enable garbage collection
-    llm = LLM(model=MODEL_NAME, max_model_len=1024, seed=0)
-
-    with llm.deprecate_legacy_api():
-        yield weakref.proxy(llm)
-        del llm
-    cleanup_dist_env_and_memory()
-
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
-                         ALL_DECODING_BACKENDS)
-def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
-                      disable_any_whitespace: bool):
-    sampling_params = SamplingParams(
-        temperature=0.8,
-        top_p=0.95,
-        guided_decoding=GuidedDecodingParams(
-            regex=sample_regex,
-            backend=guided_decoding_backend,
-            disable_any_whitespace=disable_any_whitespace))
-
-    outputs = llm.generate(prompts=[
-        f"Give an example IPv4 address with this regex: {sample_regex}"
-    ] * 2,
-                           sampling_params=sampling_params,
-                           use_tqdm=True)
-
-    assert outputs is not None
-    for output in outputs:
-        assert output is not None
-        assert isinstance(output, RequestOutput)
-        prompt = output.prompt
-        generated_text = output.outputs[0].text
-        print(generated_text)
-        assert generated_text is not None
-        assert re.fullmatch(sample_regex, generated_text) is not None
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
-                         ALL_DECODING_BACKENDS)
-def test_guided_json_completion(sample_json_schema, llm,
-                                guided_decoding_backend: str,
-                                disable_any_whitespace: bool):
-    sampling_params = SamplingParams(
-        temperature=1.0,
-        max_tokens=1000,
-        guided_decoding=GuidedDecodingParams(
-            json=sample_json_schema,
-            backend=guided_decoding_backend,
-            disable_any_whitespace=disable_any_whitespace))
-    outputs = llm.generate(prompts=[
-        f"Give an example JSON for an employee profile "
-        f"that fits this schema: {sample_json_schema}"
-    ] * 2,
-                           sampling_params=sampling_params,
-                           use_tqdm=True)
-
-    assert outputs is not None
-
-    for output in outputs:
-        assert output is not None
-        assert isinstance(output, RequestOutput)
-        prompt = output.prompt
-
-        generated_text = output.outputs[0].text
-        assert generated_text is not None
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-        output_json = json.loads(generated_text)
-        jsonschema.validate(instance=output_json, schema=sample_json_schema)
-
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
-                         ALL_DECODING_BACKENDS)
-def test_guided_complex_json_completion(sample_complex_json_schema, llm,
-                                        guided_decoding_backend: str,
-                                        disable_any_whitespace: bool):
-    sampling_params = SamplingParams(
-        temperature=1.0,
-        max_tokens=1000,
-        guided_decoding=GuidedDecodingParams(
-            json=sample_complex_json_schema,
-            backend=guided_decoding_backend,
-            disable_any_whitespace=disable_any_whitespace))
-    outputs = llm.generate(prompts=[
-        f"Give an example JSON for an assignment grade "
-        f"that fits this schema: {sample_complex_json_schema}"
-    ] * 2,
-                           sampling_params=sampling_params,
-                           use_tqdm=True)
-
-    assert outputs is not None
-
-    for output in outputs:
-        assert output is not None
-        assert isinstance(output, RequestOutput)
-        prompt = output.prompt
-
-        generated_text = output.outputs[0].text
-        assert generated_text is not None
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-        output_json = json.loads(generated_text)
-        jsonschema.validate(instance=output_json,
-                            schema=sample_complex_json_schema)
-
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
-                         ALL_DECODING_BACKENDS)
-def test_guided_definition_json_completion(sample_definition_json_schema, llm,
-                                           guided_decoding_backend: str,
-                                           disable_any_whitespace: bool):
-    sampling_params = SamplingParams(
-        temperature=1.0,
-        max_tokens=1000,
-        guided_decoding=GuidedDecodingParams(
-            json=sample_definition_json_schema,
-            backend=guided_decoding_backend,
-            disable_any_whitespace=disable_any_whitespace))
-    outputs = llm.generate(prompts=[
-        f"Give an example JSON for solving 8x + 7 = -23 "
-        f"that fits this schema: {sample_definition_json_schema}"
-    ] * 2,
-                           sampling_params=sampling_params,
-                           use_tqdm=True)
-
-    assert outputs is not None
-
-    for output in outputs:
-        assert output is not None
-        assert isinstance(output, RequestOutput)
-        prompt = output.prompt
-
-        generated_text = output.outputs[0].text
-        assert generated_text is not None
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-        output_json = json.loads(generated_text)
-        jsonschema.validate(instance=output_json,
-                            schema=sample_definition_json_schema)
-
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
-                         ALL_DECODING_BACKENDS)
-def test_guided_enum_json_completion(sample_enum_json_schema, llm,
-                                     guided_decoding_backend: str,
-                                     disable_any_whitespace: bool):
-    sampling_params = SamplingParams(
-        temperature=1.0,
-        max_tokens=1000,
-        guided_decoding=GuidedDecodingParams(
-            json=sample_enum_json_schema,
-            backend=guided_decoding_backend,
-            disable_any_whitespace=disable_any_whitespace))
-    outputs = llm.generate(prompts=[
-        "Create a bug report JSON that fits this schema: "
-        f"{sample_enum_json_schema}. Make it for a high priority critical bug."
-    ] * 2,
-                           sampling_params=sampling_params,
-                           use_tqdm=True)
-
-    assert outputs is not None
-
-    for output in outputs:
-        assert output is not None
-        assert isinstance(output, RequestOutput)
-        prompt = output.prompt
-
-        generated_text = output.outputs[0].text
-        assert generated_text is not None
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-        output_json = json.loads(generated_text)
-        jsonschema.validate(instance=output_json,
-                            schema=sample_enum_json_schema)
-
-        # Additional assertions to verify enum values
-        assert output_json["status"] in ["active", "inactive", "pending"]
-        assert output_json["priority"] in ["low", "medium", "high", "critical"]
-        assert output_json["category"]["type"] in [
-            "bug", "feature", "improvement"
-        ]
-        assert output_json["category"]["severity"] in [1, 2, 3, 4, 5]
-        for flag in output_json["flags"]:
-            assert flag in ["urgent", "blocked", "needs_review", "approved"]
-
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
-                         ALL_DECODING_BACKENDS)
-def test_guided_choice_completion(sample_guided_choice, llm,
-                                  guided_decoding_backend: str,
-                                  disable_any_whitespace: bool):
-    sampling_params = SamplingParams(
-        temperature=0.8,
-        top_p=0.95,
-        guided_decoding=GuidedDecodingParams(
-            choice=sample_guided_choice,
-            backend=guided_decoding_backend,
-            disable_any_whitespace=disable_any_whitespace))
-    outputs = llm.generate(
-        prompts="The best language for type-safe systems programming is ",
-        sampling_params=sampling_params,
-        use_tqdm=True)
-
-    assert outputs is not None
-    for output in outputs:
-        assert output is not None
-        assert isinstance(output, RequestOutput)
-        prompt = output.prompt
-        generated_text = output.outputs[0].text
-        print(generated_text)
-        assert generated_text is not None
-        assert generated_text in sample_guided_choice
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
-                         GRAMMAR_DECODING_BACKENDS)
-def test_guided_grammar(sample_sql_statements, llm,
-                        guided_decoding_backend: str,
-                        disable_any_whitespace: bool):
-    sampling_params = SamplingParams(
-        temperature=0.8,
-        top_p=0.95,
-        max_tokens=1000,
-        guided_decoding=GuidedDecodingParams(
-            grammar=sample_sql_statements,
-            backend=guided_decoding_backend,
-            disable_any_whitespace=disable_any_whitespace))
-    outputs = llm.generate(
-        prompts=("Generate a sql state that select col_1 from "
-                 "table_1 where it is equals to 1"),
-        sampling_params=sampling_params,
-        use_tqdm=True,
-    )
-
-    assert outputs is not None
-    for output in outputs:
-        assert output is not None
-        assert isinstance(output, RequestOutput)
-        prompt = output.prompt
-
-        generated_text = output.outputs[0].text
-        assert generated_text is not None
-        # use Lark to parse the output, and make sure it's a valid parse tree
-        from lark import Lark
-        parser = Lark(sample_sql_statements)
-        parser.parse(generated_text)
-
-        # remove spaces for comparison b/c we removed them in the grammar
-        ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
-            " ", "")
-
-        assert generated_text.strip() == ground_truth
-
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-
-
-@pytest.mark.skip_global_cleanup
-def test_guided_options_request_deprecation_warning(sample_regex, llm):
-    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
-
-    with pytest.warns(DeprecationWarning, match="guided_options_request"):
-        llm.generate(prompts="This should fail",
-                     sampling_params=sampling_params,
-                     use_tqdm=True,
-                     guided_options_request=dict(guided_regex=sample_regex))
-
-
-@pytest.mark.skip_global_cleanup
-def test_validation_against_both_guided_decoding_options(sample_regex, llm):
-    sampling_params = SamplingParams(
-        temperature=0.8,
-        top_p=0.95,
-        guided_decoding=GuidedDecodingParams(regex=sample_regex))
-
-    with pytest.raises(ValueError, match="Cannot set both"):
-        llm.generate(prompts="This should fail",
-                     sampling_params=sampling_params,
-                     use_tqdm=True,
-                     guided_options_request=dict(guided_regex=sample_regex))
-
-
-@pytest.mark.skip_global_cleanup
-def test_disable_guided_decoding_fallback(sample_regex, llm):
-    # see has_xgrammar_unsupported_json_features()
-    unsupported_json = {
-        "type": "object",
-        "properties": {
-            "example": {
-                "type": "string",
-                "minLength": 5  # unsupported by xgrammar
-            }
-        }
-    }
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     guided_decoding=GuidedDecodingParams(
-                                         json=unsupported_json,
-                                         backend="xgrammar",
-                                         disable_fallback=True))
-
-    with pytest.raises(
-            ValueError,
-            match="xgrammar does not support advanced JSON schema features "
-            "like string length, item limits, or property bounds."):
-        llm.generate(prompts="This should fail",
-                     sampling_params=sampling_params,
-                     use_tqdm=True)
-
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
-                         GRAMMAR_DECODING_BACKENDS)
-def test_guided_json_object(llm, guided_decoding_backend: str,
-                            disable_any_whitespace: bool):
-    sampling_params = SamplingParams(
-        temperature=1.0,
-        max_tokens=100,
-        n=2,
-        guided_decoding=GuidedDecodingParams(
-            json_object=True,
-            backend=guided_decoding_backend,
-            disable_any_whitespace=disable_any_whitespace))
-
-    outputs = llm.generate(
-        prompts=("Generate a JSON object with curly braces for a person with "
-                 "name and age fields for John Smith who is 31 years old."),
-        sampling_params=sampling_params,
-        use_tqdm=True)
-
-    assert outputs is not None
-    for output in outputs:
-        assert output is not None
-        assert isinstance(output, RequestOutput)
-
-        for i in range(2):
-            generated_text = output.outputs[i].text
-            print(generated_text)
-            assert generated_text is not None
-
-            if disable_any_whitespace:
-                assert "\n" not in generated_text
-
-            # Parse to verify it is valid JSON
-            parsed_json = json.loads(generated_text)
-            # A list is not what was intended, but is still valid
-            # json.
-            assert isinstance(parsed_json, (dict, list))
-
-
-class CarType(str, Enum):
-    sedan = "sedan"
-    suv = "SUV"
-    truck = "Truck"
-    coupe = "Coupe"
-
-
-class CarDescription(BaseModel):
-    brand: str
-    model: str
-    car_type: CarType
-
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
-                         ALL_DECODING_BACKENDS)
-def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,
-                                          disable_any_whitespace: bool):
-    json_schema = CarDescription.model_json_schema()
-    sampling_params = SamplingParams(
-        temperature=1.0,
-        max_tokens=1000,
-        guided_decoding=GuidedDecodingParams(
-            json=json_schema,
-            backend=guided_decoding_backend,
-            disable_any_whitespace=disable_any_whitespace))
-    outputs = llm.generate(
-        prompts="Generate a JSON with the brand, model and car_type of"
-        "the most iconic car from the 90's",
-        sampling_params=sampling_params,
-        use_tqdm=True)
-
-    assert outputs is not None
-    for output in outputs:
-        assert output is not None
-        assert isinstance(output, RequestOutput)
-        prompt = output.prompt
-
-        generated_text = output.outputs[0].text
-        assert generated_text is not None
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-        output_json = json.loads(generated_text)
-        jsonschema.validate(instance=output_json, schema=json_schema)
-
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
-                         ALL_DECODING_BACKENDS)
-def test_guided_number_range_json_completion(llm, guided_decoding_backend: str,
-                                             disable_any_whitespace: bool):
-    sample_output_schema = {
-        "type": "object",
-        "properties": {
-            "age": {
-                "type": "integer",
-                "minimum": 18,
-                "maximum": 99
-            },
-            "score": {
-                "type": "number",
-                "minimum": 0.0,
-                "maximum": 100.0
-            },
-            "zipcode": {
-                "type": "string",
-                "pattern": r"^\d{5}(-\d{4})?$"
-            },
-        },
-        "required": ["age", "score", "zipcode"],
-    }
-    sampling_params = SamplingParams(
-        temperature=1.0,
-        max_tokens=1000,
-        guided_decoding=GuidedDecodingParams(
-            json=sample_output_schema,
-            backend=guided_decoding_backend,
-            disable_any_whitespace=disable_any_whitespace),
-    )
-    outputs = llm.generate(
-        prompts=[
-            "Create a JSON object for a user with age, score, and zipcode."
-        ] * 2,
-        sampling_params=sampling_params,
-        use_tqdm=True,
-    )
-
-    assert outputs is not None
-
-    for output in outputs:
-        assert output is not None
-        assert isinstance(output, RequestOutput)
-        prompt = output.prompt
-
-        generated_text = output.outputs[0].text
-        assert generated_text is not None
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-        output_json = json.loads(generated_text)
-        jsonschema.validate(instance=output_json, schema=sample_output_schema)
-        assert 18 <= output_json["age"] <= 99
-        assert 0.0 <= output_json["score"] <= 100.0
-        assert (re.fullmatch(r"^\d{5}(-\d{4})?$", output_json["zipcode"])
-                is not None)
-
-
-@pytest.mark.skip_global_cleanup
-def test_guidance_no_additional_properties(llm):
-    schema = {
-        'type': 'object',
-        'properties': {
-            'a1': {
-                'type': 'string'
-            },
-            'a2': {
-                'type': 'string'
-            },
-            'a3': {
-                'type': 'string'
-            }
-        },
-        'required': ['a1', 'a2', 'a3'],
-    }
-
-    prompt = (
-        "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a "
-        "helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a "
-        "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20"
-        "<|im_end|>\n<|im_start|>assistant\n")
-
-    def generate_with_backend(backend, disable_additional_properties):
-        guided_params = GuidedDecodingParams(
-            json=schema,
-            backend=backend,
-            disable_any_whitespace=True,
-            disable_additional_properties=disable_additional_properties)
-        sampling_params = SamplingParams(temperature=0,
-                                         max_tokens=256,
-                                         guided_decoding=guided_params)
-
-        outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
-        assert outputs is not None
-        generated_text = outputs[0].outputs[0].text
-        assert generated_text is not None
-        parsed_json = json.loads(generated_text)
-        assert isinstance(parsed_json, dict)
-        jsonschema.validate(instance=parsed_json, schema=schema)
-        return parsed_json
-
-    base_generated = generate_with_backend("guidance", False)
-    assert "a1" in base_generated
-    assert "a2" in base_generated
-    assert "a3" in base_generated
-    # by default additional keys are generated
-    assert "a4" in base_generated
-    assert "a5" in base_generated
-    assert "a6" in base_generated
-
-    generated = generate_with_backend("guidance", True)
-    assert "a1" in generated
-    assert "a2" in generated
-    assert "a3" in generated
-    assert "a4" not in generated
-    assert "a5" not in generated
-    assert "a6" not in generated
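The removed tests above all share one validation pattern: decode under a constraint, then re-check the text against the same constraint with an independent library. A minimal sketch of that pattern (the helper names are ours; jsonschema and lark are the libraries the tests themselves import):

    import json

    import jsonschema
    from lark import Lark

    def check_json_output(text: str, schema: dict) -> None:
        # Schema-constrained output must both parse and validate.
        jsonschema.validate(instance=json.loads(text), schema=schema)

    def check_grammar_output(text: str, grammar: str) -> None:
        # Grammar-constrained output must be a sentence of the Lark grammar.
        Lark(grammar).parse(text)
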
@@ -4,43 +4,11 @@
 import sys
 from contextlib import nullcontext
 
 import pytest
 from vllm_test_utils import BlameResult, blame
 
 from vllm import LLM, SamplingParams
 from vllm.distributed import cleanup_dist_env_and_memory
 
 
-@pytest.fixture(scope="function", autouse=True)
-def use_v0_only(monkeypatch):
-    """
-    V1 only supports xgrammar so this is irrelevant.
-    """
-    monkeypatch.setenv('VLLM_USE_V1', '0')
-
-
-def run_normal_opt125m():
-    prompts = [
-        "Hello, my name is",
-        "The president of the United States is",
-        "The capital of France is",
-        "The future of AI is",
-    ]
-    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
-
-    # Create an LLM without guided decoding as a baseline.
-    llm = LLM(model="facebook/opt-125m",
-              enforce_eager=True,
-              gpu_memory_utilization=0.3)
-    outputs = llm.generate(prompts, sampling_params)
-    for output in outputs:
-        prompt = output.prompt
-        generated_text = output.outputs[0].text
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-
-    # Destroy the LLM object and free up the GPU memory.
-    del llm
-    cleanup_dist_env_and_memory()
+from vllm.sampling_params import GuidedDecodingParams
 
 
 def run_normal():
@@ -67,20 +35,22 @@ def run_normal():
     cleanup_dist_env_and_memory()
 
 
-def run_lmfe(sample_regex):
+def run_xgrammar(sample_regex):
     # Create an LLM with guided decoding enabled.
     llm = LLM(model="distilbert/distilgpt2",
               enforce_eager=True,
-              guided_decoding_backend="lm-format-enforcer",
+              guided_decoding_backend="xgrammar",
               gpu_memory_utilization=0.3)
-    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+    prompt = f"Give an example IPv4 address with this regex: {sample_regex}"
+    guided_decoding = GuidedDecodingParams(regex=sample_regex)
+    sampling_params = SamplingParams(temperature=0.8,
+                                     top_p=0.95,
+                                     guided_decoding=guided_decoding)
     outputs = llm.generate(
-        prompts=[
-            f"Give an example IPv4 address with this regex: {sample_regex}"
-        ] * 2,
+        prompts=[prompt] * 2,
         sampling_params=sampling_params,
         use_tqdm=True,
-        guided_options_request=dict(guided_regex=sample_regex))
+    )
 
     for output in outputs:
         prompt = output.prompt
@@ -103,7 +73,7 @@ def test_lazy_outlines(sample_regex):
         lambda: module_name in sys.modules) if use_blame else nullcontext()
     with context as result:
         run_normal()
-        run_lmfe(sample_regex)
+        run_xgrammar(sample_regex)
     if use_blame:
         assert isinstance(result, BlameResult)
         print(f"the first import location is:\n{result.trace_stack}")
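The lazy-import test above leans on a simple observable: a module only shows up in sys.modules once something imports it. A standalone sketch of the same assertion, assuming (as the test does) that the guided-decoding backend should stay unimported until a guided request needs it:

    import sys

    import vllm  # noqa: F401

    # Importing vllm alone must not drag in the guided-decoding backend.
    assert "outlines" not in sys.modules
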
@@ -488,7 +488,9 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
 
 @pytest.mark.asyncio
 async def test_guided_choice_chat(client: openai.AsyncOpenAI,
-                                  sample_guided_choice):
+                                  sample_guided_choice, is_v1_server: bool):
+    if not is_v1_server:
+        pytest.skip("Guided decoding is only supported in v1 engine")
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -524,8 +526,10 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI,
 
 
 @pytest.mark.asyncio
-async def test_guided_json_chat(client: openai.AsyncOpenAI,
-                                sample_json_schema):
+async def test_guided_json_chat(client: openai.AsyncOpenAI, sample_json_schema,
+                                is_v1_server: bool):
+    if not is_v1_server:
+        pytest.skip("Guided decoding is only supported in v1 engine")
 
     messages = [{
         "role": "system",
@@ -568,7 +572,10 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI,
 
 
 @pytest.mark.asyncio
-async def test_guided_regex_chat(client: openai.AsyncOpenAI, sample_regex):
+async def test_guided_regex_chat(client: openai.AsyncOpenAI, sample_regex,
+                                 is_v1_server: bool):
+    if not is_v1_server:
+        pytest.skip("Guided decoding is only supported in v1 engine")
 
     messages = [{
         "role": "system",
@@ -653,7 +660,10 @@ async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
 
 
 @pytest.mark.asyncio
-async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema):
+async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema,
+                              is_v1_server: bool):
+    if not is_v1_server:
+        pytest.skip("Tool use is only supported in v1 engine")
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -741,131 +751,6 @@ async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema):
     assert json1["age"] != json2["age"]
 
 
-@pytest.mark.asyncio
-@pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_required_tool_use(client: openai.AsyncOpenAI,
-                                 is_v1_server: bool, model_name: str):
-    if is_v1_server:
-        pytest.skip(
-            "tool_choice='required' requires features unsupported on V1")
-
-    tools = [
-        {
-            "type": "function",
-            "function": {
-                "name": "get_current_weather",
-                "description": "Get the current weather in a given location",
-                "parameters": {
-                    "type": "object",
-                    "properties": {
-                        "city": {
-                            "type": "string",
-                            "description":
-                            "The city to find the weather for, e.g. 'Vienna'",
-                            "default": "Vienna",
-                        },
-                        "country": {
-                            "type":
-                            "string",
-                            "description":
-                            "The country that the city is in, e.g. 'Austria'",
-                        },
-                        "unit": {
-                            "type": "string",
-                            "description":
-                            "The unit to fetch the temperature in",
-                            "enum": ["celsius", "fahrenheit"],
-                        },
-                    },
-                    "required": ["country", "unit"],
-                },
-            },
-        },
-        {
-            "type": "function",
-            "function": {
-                "name": "get_forecast",
-                "description": "Get the weather forecast for a given location",
-                "parameters": {
-                    "type": "object",
-                    "properties": {
-                        "city": {
-                            "type": "string",
-                            "description":
-                            "The city to get the forecast for, e.g. 'Vienna'",
-                            "default": "Vienna",
-                        },
-                        "country": {
-                            "type":
-                            "string",
-                            "description":
-                            "The country that the city is in, e.g. 'Austria'",
-                        },
-                        "days": {
-                            "type":
-                            "integer",
-                            "description":
-                            "Number of days to get the forecast for (1-7)",
-                        },
-                        "unit": {
-                            "type": "string",
-                            "description":
-                            "The unit to fetch the temperature in",
-                            "enum": ["celsius", "fahrenheit"],
-                        },
-                    },
-                    "required": ["country", "days", "unit"],
-                },
-            },
-        },
-    ]
-
-    messages = [
-        {
-            "role": "user",
-            "content": "Hi! How are you doing today?"
-        },
-        {
-            "role": "assistant",
-            "content": "I'm doing well! How can I help you?"
-        },
-        {
-            "role":
-            "user",
-            "content":
-            "Can you tell me what the current weather is in Berlin and the "\
-            "forecast for the next 5 days, in fahrenheit?",
-        },
-    ]
-
-    # Non-streaming test
-    chat_completion = await client.chat.completions.create(
-        messages=messages,
-        model=model_name,
-        tools=tools,
-        tool_choice="required",
-    )
-
-    assert chat_completion.choices[0].message.tool_calls is not None
-    assert len(chat_completion.choices[0].message.tool_calls) > 0
-
-    # Streaming test
-    stream = await client.chat.completions.create(
-        messages=messages,
-        model=model_name,
-        tools=tools,
-        tool_choice="required",
-        stream=True,
-    )
-
-    output = []
-    async for chunk in stream:
-        if chunk.choices and chunk.choices[0].delta.tool_calls:
-            output.extend(chunk.choices[0].delta.tool_calls)
-
-    assert len(output) > 0
-
-
 @pytest.mark.asyncio
 async def test_inconsistent_tool_choice_and_tools(client: openai.AsyncOpenAI,
                                                   sample_json_schema):
@@ -948,7 +833,11 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI):
 
 
 @pytest.mark.asyncio
-async def test_response_format_json_schema(client: openai.AsyncOpenAI):
+async def test_response_format_json_schema(client: openai.AsyncOpenAI,
+                                           is_v1_server: bool):
+    if not is_v1_server:
+        pytest.skip(
+            "JSON schema response format is only supported in v1 engine")
     prompt = 'what is 1+1? The format is "result": 2'
     # Check that this prompt cannot lead to a valid JSON without json_schema
     for _ in range(2):
@@ -28,7 +28,7 @@ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
 # but we're not testing generation quality here
 LORA_NAME = "typeof/zephyr-7b-beta-lora"
 
-GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
+GUIDED_DECODING_BACKENDS = ["outlines", "xgrammar", "guidance"]
 
 
 @pytest.fixture(scope="module")
@@ -95,6 +95,14 @@ def server(default_server_args, request):
         os.environ['VLLM_USE_V1'] = original_value
 
 
+@pytest.fixture
+def is_v1_server(server):
+    import os
+
+    # For completion tests, we assume v0 since there's no explicit v1 setup
+    return os.environ.get('VLLM_USE_V1', '0') == '1'
+
+
 @pytest_asyncio.fixture
 async def client(server):
     async with server.get_async_client() as async_client:
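The new is_v1_server fixture gives every test a uniform way to gate itself on the engine version; the hunks below repeat the same pattern. A sketch of a consumer (the test name is hypothetical; client comes from the fixture above):

    import pytest

    @pytest.mark.asyncio
    async def test_some_guided_feature(client, is_v1_server: bool):
        if not is_v1_server:
            pytest.skip("Guided decoding is only supported in v1 engine")
        ...  # issue the guided request against `client`
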
@@ -631,7 +639,10 @@ async def test_allowed_token_ids(client: openai.AsyncOpenAI):
 @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
 async def test_guided_json_completion(client: openai.AsyncOpenAI,
                                       guided_decoding_backend: str,
-                                      sample_json_schema):
+                                      sample_json_schema, is_v1_server: bool):
+    if not is_v1_server:
+        pytest.skip("Guided decoding is only supported in v1 engine")
 
     completion = await client.completions.create(
         model=MODEL_NAME,
         prompt=f"Give an example JSON for an employee profile "
@@ -653,7 +664,10 @@ async def test_guided_json_completion(client: openai.AsyncOpenAI,
 @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
 async def test_guided_regex_completion(client: openai.AsyncOpenAI,
                                        guided_decoding_backend: str,
-                                       sample_regex):
+                                       sample_regex, is_v1_server: bool):
+    if not is_v1_server:
+        pytest.skip("Guided decoding is only supported in v1 engine")
 
     completion = await client.completions.create(
         model=MODEL_NAME,
         prompt=f"Give an example IPv4 address with this regex: {sample_regex}",
@@ -674,7 +688,11 @@ async def test_guided_regex_completion(client: openai.AsyncOpenAI,
 @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
 async def test_guided_choice_completion(client: openai.AsyncOpenAI,
                                         guided_decoding_backend: str,
-                                        sample_guided_choice):
+                                        sample_guided_choice,
+                                        is_v1_server: bool):
+    if not is_v1_server:
+        pytest.skip("Guided decoding is only supported in v1 engine")
 
     completion = await client.completions.create(
         model=MODEL_NAME,
         prompt="The best language for type-safe systems programming is ",
@@ -692,7 +710,9 @@ async def test_guided_choice_completion(client: openai.AsyncOpenAI,
 
 @pytest.mark.asyncio
 async def test_guided_grammar(client: openai.AsyncOpenAI,
-                              sample_sql_statements):
+                              sample_sql_statements, is_v1_server: bool):
+    if not is_v1_server:
+        pytest.skip("Guided grammar is only supported in v1 engine")
 
     completion = await client.completions.create(
         model=MODEL_NAME,
@@ -754,7 +774,11 @@ async def test_echo_logprob_completion(client: openai.AsyncOpenAI,
 @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
 async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
                                           guided_decoding_backend: str,
-                                          sample_json_schema, sample_regex):
+                                          sample_json_schema, sample_regex,
+                                          is_v1_server: bool):
+    if not is_v1_server:
+        pytest.skip("Guided decoding is only supported in v1 engine")
 
     with pytest.raises(openai.BadRequestError):
         _ = await client.completions.create(
             model=MODEL_NAME,
@@ -9,6 +9,11 @@ import regex as re
 from ...utils import RemoteOpenAIServer
 
 
+@pytest.fixture(scope="function", autouse=True)
+def use_v1_only(monkeypatch):
+    monkeypatch.setenv('VLLM_USE_V1', '1')
+
+
 @pytest.mark.asyncio
 async def test_empty_prompt():
     model_name = "gpt2"
@@ -37,24 +42,3 @@ async def test_out_of_vocab_token_ids():
             prompt=[999999],
             max_tokens=5,
             temperature=0.0)
-
-
-@pytest.mark.asyncio
-async def test_reject_multistep_with_guided_decoding():
-    model_name = "gpt2"
-    server_args = ["--enforce-eager", "--num-scheduler-steps", "8"]
-    with RemoteOpenAIServer(model_name, server_args) as remote_server:
-        client = remote_server.get_async_client()
-
-        with pytest.raises(
-                openai.BadRequestError,
-                match=re.compile(
-                    '.*Guided decoding .* multi-step decoding.*').pattern):
-            await client.completions.create(
-                model=model_name,
-                prompt="Hello",
-                max_tokens=5,
-                temperature=0.0,
-                extra_body={"response_format": {
-                    "type": "json_object"
-                }})