[Frontend][Core] Move guided decoding params into sampling params (#8252)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
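For context, this change makes guided decoding a field of SamplingParams instead of a separate request argument. Below is a minimal sketch of the resulting call pattern, assuming the GuidedDecodingParams container introduced alongside this change; the snippet is illustrative and not part of the diff itself:

from vllm import LLM
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

llm = LLM(model="facebook/opt-125m")

# Guided decoding now rides inside SamplingParams, so it flows through
# the same path as temperature, logit_bias, allowed_token_ids, etc.
params = SamplingParams(
    temperature=0.0,
    max_tokens=16,
    guided_decoding=GuidedDecodingParams(choice=["positive", "negative"]),
)

outputs = llm.generate("The sentiment of 'great movie!' is", params)
print(outputs[0].outputs[0].text)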
@@ -25,6 +25,7 @@ from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster
@@ -33,6 +34,8 @@ from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
    get_local_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
                          RequestOutputFactory)
@@ -843,6 +846,9 @@ class LLMEngine:
            raise ValueError(f"Cannot request more than "
                             f"{max_logprobs} logprobs.")

        sampling_params = self._build_logits_processors(
            sampling_params, lora_request)

        # Defensive copy of SamplingParams, which are used by the sampler;
        # this doesn't deep-copy LogitsProcessor objects
        sampling_params = sampling_params.clone()
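The "defensive copy" comment above hides a subtlety worth spelling out: the clone protects the engine from the caller mutating their SamplingParams afterwards, while the LogitsProcessor objects just attached (which may hold large guided-decoding state) stay shared by reference. A minimal sketch, not vLLM code, with a hypothetical FakeParams standing in for SamplingParams:

import copy

class FakeParams:
    """Hypothetical stand-in for SamplingParams."""

    def __init__(self, temperature, logits_processors):
        self.temperature = temperature
        self.logits_processors = logits_processors

    def clone(self):
        # Copy scalar fields and the list itself, but keep the processor
        # objects shared by reference (no deep copy).
        new = copy.copy(self)
        new.logits_processors = list(self.logits_processors)
        return new

shared_processor = object()  # pretend this is an FSM-backed guided-decoding processor
user_params = FakeParams(temperature=0.7, logits_processors=[shared_processor])
engine_params = user_params.clone()

user_params.temperature = 1.5  # caller mutates their object later
assert engine_params.temperature == 0.7  # engine's copy is unaffected
assert engine_params.logits_processors[0] is shared_processor  # processor not duplicated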
@@ -1895,3 +1901,51 @@ class LLMEngine:
            # TODO: Find out how many placeholder tokens are there so we can
            # check that chunked prefill does not truncate them
            # max_batch_len = self.scheduler_config.max_num_batched_tokens

    def _build_logits_processors(
            self, sampling_params: SamplingParams,
            lora_request: Optional[LoRARequest]) -> SamplingParams:
        """Constructs logits processors based on the guided_decoding,
        logit_bias, and allowed_token_ids fields in sampling_params. Deletes
        those fields and adds the constructed logits processors to the
        logits_processors field. Returns the modified sampling params."""

        logits_processors = []
        if (guided_decoding := sampling_params.guided_decoding) is not None:

            logger.debug(
                "Building guided decoding logits processor in "
                "LLMEngine. Params: %s", guided_decoding)

            tokenizer = self.get_tokenizer(lora_request=lora_request)
            guided_decoding.backend = guided_decoding.backend or \
                self.decoding_config.guided_decoding_backend

            processor = get_local_guided_decoding_logits_processor(
                guided_params=guided_decoding, tokenizer=tokenizer)
            if processor:
                logits_processors.append(processor)

            # Unset so this doesn't get passed down to the model
            sampling_params.guided_decoding = None

        if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
            tokenizer = self.get_tokenizer(lora_request=lora_request)

            processors = get_logits_processors(
                logit_bias=sampling_params.logit_bias,
                allowed_token_ids=sampling_params.allowed_token_ids,
                tokenizer=tokenizer)
            logits_processors.extend(processors)

            # Unset so these don't get passed down to the model
            sampling_params.logit_bias = None
            sampling_params.allowed_token_ids = None

        if logits_processors:
            if sampling_params.logits_processors is None:
                sampling_params.logits_processors = logits_processors
            else:
                sampling_params.logits_processors.extend(logits_processors)

        return sampling_params
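For readers unfamiliar with the logits-processor contract these helpers produce, here is an illustrative sketch, not vLLM's implementation, assuming the common (generated_token_ids, logits) -> logits callable shape; make_bias_processor is a hypothetical name:

from typing import Dict, List

import torch

def make_bias_processor(logit_bias: Dict[int, float]):
    """Build a processor that adds a fixed bias to selected token logits."""

    def processor(token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
        # Applied once per decoding step, before sampling the next token.
        for token_id, bias in logit_bias.items():
            logits[token_id] += bias
        return logits

    return processor

processor = make_bias_processor({42: 5.0})
logits = torch.zeros(100)
biased = processor([1, 2, 3], logits)
assert biased[42] == 5.0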