[Frontend] Bad words sampling parameter (#9717)

Signed-off-by: Vasily Alexeev <alvasian@yandex.ru>
This commit is contained in:
Vasiliy Alekseev
2024-10-26 19:29:38 +03:00
committed by GitHub
parent 55137e8ee3
commit 07e981fdf4
6 changed files with 339 additions and 16 deletions

View File

@@ -3,14 +3,14 @@ import copy
from dataclasses import dataclass
from enum import Enum, IntEnum
from functools import cached_property
from typing import Any, Callable, Dict, List, Optional, Set, Union
from typing import Any, Dict, List, Optional, Set, Union
import msgspec
import torch
from pydantic import BaseModel
from typing_extensions import Annotated
from vllm.logger import init_logger
from vllm.logits_process import LogitsProcessor
logger = init_logger(__name__)
@@ -24,16 +24,6 @@ class SamplingType(IntEnum):
RANDOM_SEED = 2
LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor],
Callable[[List[int], List[int], torch.Tensor],
torch.Tensor]]
"""LogitsProcessor is a function that takes a list
of previously generated tokens, the logits tensor
for the next token and, optionally, prompt tokens as a
first argument, and returns a modified tensor of logits
to sample from."""
# maybe make msgspec?
@dataclass
class GuidedDecodingParams:
@@ -139,6 +129,10 @@ class SamplingParams(
stop_token_ids: List of tokens that stop the generation when they are
generated. The returned output will contain the stop tokens unless
the stop tokens are special tokens.
bad_words: List of words that are not allowed to be generated.
More precisely, only the last token of a corresponding
token sequence is not allowed when the next generated token
can complete the sequence.
include_stop_str_in_output: Whether to include the stop strings in
output text. Defaults to False.
ignore_eos: Whether to ignore the EOS token and continue generating
@@ -186,6 +180,7 @@ class SamplingParams(
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stop_token_ids: Optional[List[int]] = None
bad_words: Optional[List[str]] = None
ignore_eos: bool = False
max_tokens: Optional[int] = 16
min_tokens: int = 0
@@ -228,6 +223,7 @@ class SamplingParams(
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
bad_words: Optional[List[str]] = None,
include_stop_str_in_output: bool = False,
ignore_eos: bool = False,
max_tokens: Optional[int] = 16,
@@ -267,6 +263,7 @@ class SamplingParams(
seed=seed,
stop=stop,
stop_token_ids=stop_token_ids,
bad_words=bad_words,
include_stop_str_in_output=include_stop_str_in_output,
ignore_eos=ignore_eos,
max_tokens=max_tokens,
@@ -298,26 +295,36 @@ class SamplingParams(
f"got n={self.n} and best_of={self.best_of}.")
self._real_n = self.n
self.n = self.best_of
if 0 < self.temperature < _MAX_TEMP:
logger.warning(
"temperature %s is less than %s, which may cause numerical "
"errors nan or inf in tensors. We have maxed it out to %s.",
self.temperature, _MAX_TEMP, _MAX_TEMP)
self.temperature = max(self.temperature, _MAX_TEMP)
if self.seed == -1:
self.seed = None
else:
self.seed = self.seed
if self.stop is None:
self.stop = []
elif isinstance(self.stop, str):
self.stop = [self.stop]
else:
self.stop = list(self.stop)
if self.stop_token_ids is None:
self.stop_token_ids = []
else:
self.stop_token_ids = list(self.stop_token_ids)
if self.bad_words is None:
self.bad_words = []
else:
self.bad_words = list(self.bad_words)
self.logprobs = 1 if self.logprobs is True else self.logprobs
self.prompt_logprobs = (1 if self.prompt_logprobs is True else
self.prompt_logprobs)
@@ -468,6 +475,7 @@ class SamplingParams(
f"seed={self.seed}, "
f"stop={self.stop}, "
f"stop_token_ids={self.stop_token_ids}, "
f"bad_words={self.bad_words}, "
f"include_stop_str_in_output={self.include_stop_str_in_output}, "
f"ignore_eos={self.ignore_eos}, "
f"max_tokens={self.max_tokens}, "