[Frontend] Bad words sampling parameter (#9717)

Signed-off-by: Vasily Alexeev <alvasian@yandex.ru>
This commit is contained in:
Vasiliy Alekseev
2024-10-26 19:29:38 +03:00
committed by GitHub
parent 55137e8ee3
commit 07e981fdf4
6 changed files with 339 additions and 16 deletions

View File

@@ -26,7 +26,8 @@ from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.entrypoints.openai.logits_processors import (
get_logits_processors as get_openai_logits_processors)
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster
@@ -34,6 +35,7 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
EncoderDecoderInputs, InputRegistry, PromptType)
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.logits_process import get_bad_words_logits_processors
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_local_guided_decoding_logits_processor)
@@ -1963,6 +1965,7 @@ class LLMEngine:
logits_processors field. Returns the modified sampling params."""
logits_processors = []
if (guided_decoding := sampling_params.guided_decoding) is not None:
logger.debug(
@@ -1984,7 +1987,7 @@ class LLMEngine:
if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
tokenizer = self.get_tokenizer(lora_request=lora_request)
processors = get_logits_processors(
processors = get_openai_logits_processors(
logit_bias=sampling_params.logit_bias,
allowed_token_ids=sampling_params.allowed_token_ids,
tokenizer=tokenizer)
@@ -1994,6 +1997,12 @@ class LLMEngine:
sampling_params.logit_bias = None
sampling_params.allowed_token_ids = None
if len(sampling_params.bad_words) > 0:
tokenizer = self.get_tokenizer(lora_request)
processors = get_bad_words_logits_processors(
bad_words=sampling_params.bad_words, tokenizer=tokenizer)
logits_processors.extend(processors)
if logits_processors:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = logits_processors