[Frontend] Bad words sampling parameter (#9717)
Signed-off-by: Vasily Alexeev <alvasian@yandex.ru>
This commit is contained in:
@@ -26,7 +26,8 @@ from vllm.engine.output_processor.interfaces import (
|
||||
SequenceGroupOutputProcessor)
|
||||
from vllm.engine.output_processor.stop_checker import StopChecker
|
||||
from vllm.engine.output_processor.util import create_output_by_sequence_group
|
||||
from vllm.entrypoints.openai.logits_processors import get_logits_processors
|
||||
from vllm.entrypoints.openai.logits_processors import (
|
||||
get_logits_processors as get_openai_logits_processors)
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.executor.gpu_executor import GPUExecutor
|
||||
from vllm.executor.ray_utils import initialize_ray_cluster
|
||||
@@ -34,6 +35,7 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
|
||||
EncoderDecoderInputs, InputRegistry, PromptType)
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logits_process import get_bad_words_logits_processors
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.guided_decoding import (
|
||||
get_local_guided_decoding_logits_processor)
|
||||
@@ -1963,6 +1965,7 @@ class LLMEngine:
|
||||
logits_processors field. Returns the modified sampling params."""
|
||||
|
||||
logits_processors = []
|
||||
|
||||
if (guided_decoding := sampling_params.guided_decoding) is not None:
|
||||
|
||||
logger.debug(
|
||||
@@ -1984,7 +1987,7 @@ class LLMEngine:
|
||||
if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
|
||||
tokenizer = self.get_tokenizer(lora_request=lora_request)
|
||||
|
||||
processors = get_logits_processors(
|
||||
processors = get_openai_logits_processors(
|
||||
logit_bias=sampling_params.logit_bias,
|
||||
allowed_token_ids=sampling_params.allowed_token_ids,
|
||||
tokenizer=tokenizer)
|
||||
@@ -1994,6 +1997,12 @@ class LLMEngine:
|
||||
sampling_params.logit_bias = None
|
||||
sampling_params.allowed_token_ids = None
|
||||
|
||||
if len(sampling_params.bad_words) > 0:
|
||||
tokenizer = self.get_tokenizer(lora_request)
|
||||
processors = get_bad_words_logits_processors(
|
||||
bad_words=sampling_params.bad_words, tokenizer=tokenizer)
|
||||
logits_processors.extend(processors)
|
||||
|
||||
if logits_processors:
|
||||
if sampling_params.logits_processors is None:
|
||||
sampling_params.logits_processors = logits_processors
|
||||
|
||||
Reference in New Issue
Block a user