[Frontend] Bad words sampling parameter (#9717)
Signed-off-by: Vasily Alexeev <alvasian@yandex.ru>
This commit is contained in:
@@ -3,14 +3,14 @@ import copy
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, IntEnum
|
||||
from functools import cached_property
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Union
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
|
||||
import msgspec
|
||||
import torch
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logits_process import LogitsProcessor
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -24,16 +24,6 @@ class SamplingType(IntEnum):
|
||||
RANDOM_SEED = 2
|
||||
|
||||
|
||||
LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor],
|
||||
Callable[[List[int], List[int], torch.Tensor],
|
||||
torch.Tensor]]
|
||||
"""LogitsProcessor is a function that takes a list
|
||||
of previously generated tokens, the logits tensor
|
||||
for the next token and, optionally, prompt tokens as a
|
||||
first argument, and returns a modified tensor of logits
|
||||
to sample from."""
|
||||
|
||||
|
||||
# TODO: consider making this a msgspec.Struct instead of a dataclass.
|
||||
@dataclass
|
||||
class GuidedDecodingParams:
|
||||
@@ -139,6 +129,10 @@ class SamplingParams(
|
||||
stop_token_ids: List of tokens that stop the generation when they are
|
||||
generated. The returned output will contain the stop tokens unless
|
||||
the stop tokens are special tokens.
|
||||
bad_words: List of words that are not allowed to be generated.
|
||||
More precisely, for each bad word only the last token of its
|
||||
token sequence is disallowed, and only when the next generated
|
||||
token would complete that sequence.
|
||||
include_stop_str_in_output: Whether to include the stop strings in
|
||||
output text. Defaults to False.
|
||||
ignore_eos: Whether to ignore the EOS token and continue generating
|
||||
@@ -186,6 +180,7 @@ class SamplingParams(
|
||||
seed: Optional[int] = None
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
stop_token_ids: Optional[List[int]] = None
|
||||
bad_words: Optional[List[str]] = None
|
||||
ignore_eos: bool = False
|
||||
max_tokens: Optional[int] = 16
|
||||
min_tokens: int = 0
|
||||
@@ -228,6 +223,7 @@ class SamplingParams(
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
bad_words: Optional[List[str]] = None,
|
||||
include_stop_str_in_output: bool = False,
|
||||
ignore_eos: bool = False,
|
||||
max_tokens: Optional[int] = 16,
|
||||
@@ -267,6 +263,7 @@ class SamplingParams(
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stop_token_ids=stop_token_ids,
|
||||
bad_words=bad_words,
|
||||
include_stop_str_in_output=include_stop_str_in_output,
|
||||
ignore_eos=ignore_eos,
|
||||
max_tokens=max_tokens,
|
||||
@@ -298,26 +295,36 @@ class SamplingParams(
|
||||
f"got n={self.n} and best_of={self.best_of}.")
|
||||
self._real_n = self.n
|
||||
self.n = self.best_of
|
||||
|
||||
if 0 < self.temperature < _MAX_TEMP:
|
||||
logger.warning(
|
||||
"temperature %s is less than %s, which may cause numerical "
|
||||
"errors nan or inf in tensors. We have maxed it out to %s.",
|
||||
self.temperature, _MAX_TEMP, _MAX_TEMP)
|
||||
self.temperature = max(self.temperature, _MAX_TEMP)
|
||||
|
||||
if self.seed == -1:
|
||||
self.seed = None
|
||||
else:
|
||||
self.seed = self.seed
|
||||
|
||||
if self.stop is None:
|
||||
self.stop = []
|
||||
elif isinstance(self.stop, str):
|
||||
self.stop = [self.stop]
|
||||
else:
|
||||
self.stop = list(self.stop)
|
||||
|
||||
if self.stop_token_ids is None:
|
||||
self.stop_token_ids = []
|
||||
else:
|
||||
self.stop_token_ids = list(self.stop_token_ids)
|
||||
|
||||
if self.bad_words is None:
|
||||
self.bad_words = []
|
||||
else:
|
||||
self.bad_words = list(self.bad_words)
|
||||
|
||||
self.logprobs = 1 if self.logprobs is True else self.logprobs
|
||||
self.prompt_logprobs = (1 if self.prompt_logprobs is True else
|
||||
self.prompt_logprobs)
|
||||
@@ -468,6 +475,7 @@ class SamplingParams(
|
||||
f"seed={self.seed}, "
|
||||
f"stop={self.stop}, "
|
||||
f"stop_token_ids={self.stop_token_ids}, "
|
||||
f"bad_words={self.bad_words}, "
|
||||
f"include_stop_str_in_output={self.include_stop_str_in_output}, "
|
||||
f"ignore_eos={self.ignore_eos}, "
|
||||
f"max_tokens={self.max_tokens}, "
|
||||
|
||||
Reference in New Issue
Block a user