[V0 deprecation] Guided decoding (#21347)
Signed-off-by: Reza Barazesh <rezabarazesh@meta.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
vllm/engine/async_llm_engine.py

@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import asyncio
-import copy
 import time
 import weakref
 from functools import partial
@@ -24,8 +23,6 @@ from vllm.inputs import PromptType
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.model_executor.guided_decoding import (
-    get_guided_decoding_logits_processor)
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
@@ -469,19 +466,6 @@ class _AsyncLLMEngine(LLMEngine):
             tokenization_kwargs=tokenization_kwargs,
         )
 
-        if isinstance(params, SamplingParams) and \
-                params.guided_decoding is not None:
-            # Guided decoding has an async implementation for building logits
-            # processors in a separate threadpool.
-            # We want to invoke that here instead of using the blocking
-            # implementation in the LLMEngine
-            params = await build_guided_decoding_logits_processor_async(
-                sampling_params=params,
-                tokenizer=await self.get_tokenizer_async(lora_request),
-                default_guided_backend=self.decoding_config.backend,
-                reasoning_backend=self.decoding_config.reasoning_backend,
-                model_config=self.model_config)
-
         self._add_processed_request(
             request_id=request_id,
             processed_inputs=processed_inputs,
@@ -503,48 +487,6 @@ class _AsyncLLMEngine(LLMEngine):
         raise NotImplementedError
 
 
-async def build_guided_decoding_logits_processor_async(
-        sampling_params: SamplingParams, tokenizer: AnyTokenizer,
-        default_guided_backend: str, reasoning_backend: Optional[str],
-        model_config: ModelConfig) -> SamplingParams:
-    """Constructs logits processors based on the guided_decoding,
-    logits_bias, and allowed_token_ids fields in sampling_params. Deletes
-    those fields and adds the constructed logits processors to the
-    logits_processors field. Modifies sampling params in-place and returns
-    the modified sampling params."""
-    if sampling_params.guided_decoding is None:
-        return sampling_params
-
-    # Defensively copy sampling params since guided decoding logits
-    # processors can have different state for each request
-    sampling_params = copy.copy(sampling_params)
-    guided_decoding = sampling_params.guided_decoding
-
-    logger.debug(
-        "Building guided decoding logits processor. "
-        "guided_decoding: %s%s", guided_decoding,
-        f", reasoning_backend: {reasoning_backend}"
-        if reasoning_backend is not None else "")
-
-    guided_decoding.backend = guided_decoding.backend or default_guided_backend
-
-    processor = await get_guided_decoding_logits_processor(
-        guided_params=guided_decoding,
-        tokenizer=tokenizer,
-        reasoning_backend=reasoning_backend,
-        model_config=model_config)
-
-    if processor:
-        if sampling_params.logits_processors is None:
-            sampling_params.logits_processors = []
-        sampling_params.logits_processors.append(processor)
-
-    # Unset guided decoding params after constructing the lp from them
-    sampling_params.guided_decoding = None
-
-    return sampling_params
-
-
 class AsyncLLMEngine(EngineClient):
     """An asynchronous wrapper for [`LLMEngine`][vllm.LLMEngine].
 
@@ -1028,7 +970,7 @@ class AsyncLLMEngine(EngineClient):
         ```
         # Please refer to entrypoints/api_server.py for
         # the complete example.
 
         # initialize the engine and the example input
         # note that engine_args here is AsyncEngineArgs instance
         engine = AsyncLLMEngine.from_engine_args(engine_args)
@@ -1036,13 +978,13 @@ class AsyncLLMEngine(EngineClient):
             "input": "What is LLM?",
             "request_id": 0,
         }
 
         # start the generation
         results_generator = engine.encode(
            example_input["input"],
            PoolingParams(),
            example_input["request_id"])
 
         # get the results
         final_output = None
         async for request_output in results_generator:
@@ -1052,7 +994,7 @@ class AsyncLLMEngine(EngineClient):
                 # Return or raise an error
                 ...
             final_output = request_output
 
         # Process and return the final output
         ...
         ```
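
For context on the comment in the removed block: building a guided-decoding logits processor (for example, compiling a JSON schema or grammar into a token-level automaton) can block for a noticeable time, which is why the deleted helper ran the build off the event loop in a threadpool. Below is a minimal sketch of that general asyncio pattern, not vLLM code; `build_processor` and its cost are hypothetical stand-ins:

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor


def build_processor(schema: str) -> str:
    # Hypothetical stand-in for an expensive, blocking build step
    # (e.g. compiling a JSON schema into a token-level automaton).
    return f"logits processor for {schema}"


_executor = ThreadPoolExecutor(max_workers=2)


async def build_processor_async(schema: str) -> str:
    # Offload the blocking build to a worker thread so the event loop
    # keeps serving other requests while it runs -- the same idea the
    # removed build_guided_decoding_logits_processor_async relied on.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(_executor, build_processor, schema)


async def main() -> None:
    print(await build_processor_async('{"type": "object"}'))


asyncio.run(main())
```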
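With this V0 async path gone, guided decoding is driven entirely through request parameters and handled inside the engine. A minimal sketch of the surviving user-facing API, assuming the `GuidedDecodingParams` interface documented for vLLM structured outputs; the model name and prompt are placeholders:

```python
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Constrain generation to one of two labels; the engine builds the
# matching logits processor internally, so callers never invoke a
# helper like build_guided_decoding_logits_processor_async themselves.
guided = GuidedDecodingParams(choice=["Positive", "Negative"])
params = SamplingParams(guided_decoding=guided)

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")  # placeholder model
outputs = llm.generate("Classify the sentiment: 'vLLM is great!'", params)
print(outputs[0].outputs[0].text)
```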