[V0 deprecation] Guided decoding (#21347)

Signed-off-by: Reza Barazesh <rezabarazesh@meta.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Reza Barazesh
2025-07-29 03:15:30 -07:00
committed by GitHub
parent a4528f0cac
commit 37efc63b64
29 changed files with 103 additions and 2809 deletions

View File

@@ -4,43 +4,11 @@
import sys
from contextlib import nullcontext
import pytest
from vllm_test_utils import BlameResult, blame
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
"""
V1 only supports xgrammar so this is irrelevant.
"""
monkeypatch.setenv('VLLM_USE_V1', '0')
def run_normal_opt125m():
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM without guided decoding as a baseline.
llm = LLM(model="facebook/opt-125m",
enforce_eager=True,
gpu_memory_utilization=0.3)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
# Destroy the LLM object and free up the GPU memory.
del llm
cleanup_dist_env_and_memory()
from vllm.sampling_params import GuidedDecodingParams
def run_normal():
@@ -67,20 +35,22 @@ def run_normal():
cleanup_dist_env_and_memory()
def run_lmfe(sample_regex):
def run_xgrammar(sample_regex):
# Create an LLM with guided decoding enabled.
llm = LLM(model="distilbert/distilgpt2",
enforce_eager=True,
guided_decoding_backend="lm-format-enforcer",
guided_decoding_backend="xgrammar",
gpu_memory_utilization=0.3)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
prompt = f"Give an example IPv4 address with this regex: {sample_regex}"
guided_decoding = GuidedDecodingParams(regex=sample_regex)
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
guided_decoding=guided_decoding)
outputs = llm.generate(
prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}"
] * 2,
prompts=[prompt] * 2,
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_regex=sample_regex))
)
for output in outputs:
prompt = output.prompt
@@ -103,7 +73,7 @@ def test_lazy_outlines(sample_regex):
lambda: module_name in sys.modules) if use_blame else nullcontext()
with context as result:
run_normal()
run_lmfe(sample_regex)
run_xgrammar(sample_regex)
if use_blame:
assert isinstance(result, BlameResult)
print(f"the first import location is:\n{result.trace_stack}")