[V0 deprecation] Guided decoding (#21347)

Signed-off-by: Reza Barazesh <rezabarazesh@meta.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Reza Barazesh
2025-07-29 03:15:30 -07:00
committed by GitHub
parent a4528f0cac
commit 37efc63b64
29 changed files with 103 additions and 2809 deletions

View File

@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import copy
import time
import weakref
from functools import partial
@@ -24,8 +23,6 @@ from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
@@ -469,19 +466,6 @@ class _AsyncLLMEngine(LLMEngine):
tokenization_kwargs=tokenization_kwargs,
)
if isinstance(params, SamplingParams) and \
params.guided_decoding is not None:
# Guided decoding has an async implementation for building logits
# processors in a separate threadpool.
# We want to invoke that here instead of using the blocking
# implementation in the LLMEngine
params = await build_guided_decoding_logits_processor_async(
sampling_params=params,
tokenizer=await self.get_tokenizer_async(lora_request),
default_guided_backend=self.decoding_config.backend,
reasoning_backend=self.decoding_config.reasoning_backend,
model_config=self.model_config)
self._add_processed_request(
request_id=request_id,
processed_inputs=processed_inputs,
@@ -503,48 +487,6 @@ class _AsyncLLMEngine(LLMEngine):
raise NotImplementedError
async def build_guided_decoding_logits_processor_async(
        sampling_params: SamplingParams, tokenizer: AnyTokenizer,
        default_guided_backend: str, reasoning_backend: Optional[str],
        model_config: ModelConfig) -> SamplingParams:
    """Build the guided-decoding logits processor for a single request.

    If ``sampling_params.guided_decoding`` is ``None`` the input object is
    returned untouched. Otherwise a shallow copy of ``sampling_params`` is
    made, a logits processor is constructed from the guided-decoding
    parameters and added to the copy's ``logits_processors``, and the copy's
    ``guided_decoding`` field is cleared. The caller's ``sampling_params``,
    its ``guided_decoding`` object and its ``logits_processors`` list are
    left unmodified.

    Args:
        sampling_params: Request sampling parameters.
        tokenizer: Tokenizer forwarded to the logits-processor factory.
        default_guided_backend: Backend used when the guided-decoding
            parameters do not name one themselves.
        reasoning_backend: Optional reasoning backend forwarded to the
            logits-processor factory.
        model_config: Model configuration forwarded to the factory.

    Returns:
        ``sampling_params`` itself when no guided decoding was requested,
        otherwise the modified copy.
    """
    if sampling_params.guided_decoding is None:
        return sampling_params

    # Defensively copy sampling params since guided decoding logits
    # processors can have different state for each request.
    sampling_params = copy.copy(sampling_params)
    # copy.copy() is shallow: the guided-decoding object is still shared with
    # the caller, so copy it as well before filling in the default backend.
    guided_decoding = copy.copy(sampling_params.guided_decoding)

    logger.debug(
        "Building guided decoding logits processor. "
        "guided_decoding: %s%s", guided_decoding,
        f", reasoning_backend: {reasoning_backend}"
        if reasoning_backend is not None else "")

    guided_decoding.backend = guided_decoding.backend or default_guided_backend

    processor = await get_guided_decoding_logits_processor(
        guided_params=guided_decoding,
        tokenizer=tokenizer,
        reasoning_backend=reasoning_backend,
        model_config=model_config)

    if processor:
        # Build a fresh list instead of appending: the shallow copy shares the
        # caller's list, and appending would leak this request's processor
        # into every later request that reuses the same SamplingParams.
        existing_processors = sampling_params.logits_processors or []
        sampling_params.logits_processors = [*existing_processors, processor]

    # Unset guided decoding params after constructing the lp from them
    sampling_params.guided_decoding = None

    return sampling_params
class AsyncLLMEngine(EngineClient):
"""An asynchronous wrapper for [`LLMEngine`][vllm.LLMEngine].
@@ -1028,7 +970,7 @@ class AsyncLLMEngine(EngineClient):
```
# Please refer to entrypoints/api_server.py for
# the complete example.
# initialize the engine and the example input
# note that engine_args here is AsyncEngineArgs instance
engine = AsyncLLMEngine.from_engine_args(engine_args)
@@ -1036,13 +978,13 @@ class AsyncLLMEngine(EngineClient):
"input": "What is LLM?",
"request_id": 0,
}
# start the generation
results_generator = engine.encode(
example_input["input"],
PoolingParams(),
example_input["request_id"])
# get the results
final_output = None
async for request_output in results_generator:
@@ -1052,7 +994,7 @@ class AsyncLLMEngine(EngineClient):
# Return or raise an error
...
final_output = request_output
# Process and return the final output
...
```