[Frontend][Core] Move guided decoding params into sampling params (#8252)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com> Co-authored-by: Nick Hill <nickhill@us.ibm.com>
This commit is contained in:
@@ -20,6 +20,8 @@ from vllm.executor.ray_utils import initialize_ray_cluster
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.guided_decoding import (
|
||||
get_guided_decoding_logits_processor)
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
@@ -477,6 +479,18 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
)
|
||||
processed_inputs = self.input_processor(preprocessed_inputs)
|
||||
|
||||
if isinstance(params, SamplingParams) and \
|
||||
params.guided_decoding is not None:
|
||||
# Guided decoding has an async implementation for building logits
|
||||
# processors in a separate threadpool.
|
||||
# We want to invoke that here instead of using the blocking
|
||||
# implementation in the LLMEngine
|
||||
params = await build_guided_decoding_logits_processor_async(
|
||||
sampling_params=params,
|
||||
tokenizer=self.get_tokenizer(lora_request),
|
||||
default_guided_backend=self.decoding_config.
|
||||
guided_decoding_backend)
|
||||
|
||||
self._add_processed_request(
|
||||
request_id=request_id,
|
||||
processed_inputs=processed_inputs,
|
||||
@@ -494,6 +508,36 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
self.model_executor.check_health()
|
||||
|
||||
|
||||
async def build_guided_decoding_logits_processor_async(
        sampling_params: SamplingParams, tokenizer: AnyTokenizer,
        default_guided_backend: str) -> SamplingParams:
    """Turn ``sampling_params.guided_decoding`` into a logits processor.

    When ``sampling_params.guided_decoding`` is ``None`` the params are
    returned untouched. Otherwise:

    * the guided decoding params' ``backend`` falls back to
      ``default_guided_backend`` when it is not already set (falsy),
    * the processor is awaited from
      ``get_guided_decoding_logits_processor``,
    * a truthy processor is appended to
      ``sampling_params.logits_processors`` (an empty list is created
      first if that field is ``None``),
    * ``sampling_params.guided_decoding`` is reset to ``None``.

    Both ``sampling_params`` and its guided decoding params object are
    modified in place; the very same ``sampling_params`` object is
    returned.
    """
    # NOTE(review): the caller's objects are mutated here. If one
    # SamplingParams instance is reused for several requests, later calls
    # will find guided_decoding already cleared -- confirm that callers
    # never share instances, or consider a defensive copy.
    guided_params = sampling_params.guided_decoding
    if guided_params is None:
        return sampling_params

    logger.debug("Building guided decoding logits processor. "
                 "Params: %s", guided_params)

    # A backend chosen on the request wins over the engine-wide default.
    chosen_backend = guided_params.backend or default_guided_backend
    guided_params.backend = chosen_backend

    logits_processor = await get_guided_decoding_logits_processor(
        guided_params=guided_params, tokenizer=tokenizer)

    if logits_processor:
        processors = sampling_params.logits_processors
        if processors is None:
            # Bind the fresh list to the field and to the local name so
            # the append below lands in the list stored on the params.
            processors = sampling_params.logits_processors = []
        processors.append(logits_processor)

    # The guided decoding params have been consumed; clear them.
    sampling_params.guided_decoding = None

    return sampling_params
|
||||
|
||||
|
||||
class AsyncLLMEngine:
|
||||
"""An asynchronous wrapper for :class:`LLMEngine`.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user