[Frontend][Core] Move guided decoding params into sampling params (#8252)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
This commit is contained in:
Joe Runde
2024-09-30 19:34:25 -06:00
committed by GitHub
parent bce324487a
commit 062c89e7c9
16 changed files with 441 additions and 281 deletions

View File

@@ -25,6 +25,7 @@ from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster
@@ -33,6 +34,8 @@ from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_local_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory)
@@ -843,6 +846,9 @@ class LLMEngine:
raise ValueError(f"Cannot request more than "
f"{max_logprobs} logprobs.")
sampling_params = self._build_logits_processors(
sampling_params, lora_request)
# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone()
@@ -1895,3 +1901,51 @@ class LLMEngine:
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens
def _build_logits_processors(
self, sampling_params: SamplingParams,
lora_request: Optional[LoRARequest]) -> SamplingParams:
"""Constructs logits processors based on the guided_decoding,
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
those fields and adds the constructed logits processors to the
logits_processors field. Returns the modified sampling params."""
logits_processors = []
if (guided_decoding := sampling_params.guided_decoding) is not None:
logger.debug(
"Building guided decoding logits processor in "
"LLMEngine. Params: %s", guided_decoding)
tokenizer = self.get_tokenizer(lora_request=lora_request)
guided_decoding.backend = guided_decoding.backend or \
self.decoding_config.guided_decoding_backend
processor = get_local_guided_decoding_logits_processor(
guided_params=guided_decoding, tokenizer=tokenizer)
if processor:
logits_processors.append(processor)
# Unset so this doesn't get passed down to the model
sampling_params.guided_decoding = None
if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
tokenizer = self.get_tokenizer(lora_request=lora_request)
processors = get_logits_processors(
logit_bias=sampling_params.logit_bias,
allowed_token_ids=sampling_params.allowed_token_ids,
tokenizer=tokenizer)
logits_processors.extend(processors)
# Unset so these don't get passed down to the model
sampling_params.logit_bias = None
sampling_params.allowed_token_ids = None
if logits_processors:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = logits_processors
else:
sampling_params.logits_processors.extend(logits_processors)
return sampling_params