[V1] Support bad_words in sampler (#13376)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -11,6 +11,8 @@ from pydantic import BaseModel
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logits_process import LogitsProcessor
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -202,7 +204,6 @@ class SamplingParams(
|
||||
seed: Optional[int] = None
|
||||
stop: Optional[Union[str, list[str]]] = None
|
||||
stop_token_ids: Optional[list[int]] = None
|
||||
bad_words: Optional[list[str]] = None
|
||||
ignore_eos: bool = False
|
||||
max_tokens: Optional[int] = 16
|
||||
min_tokens: int = 0
|
||||
@@ -232,6 +233,10 @@ class SamplingParams(
|
||||
allowed_token_ids: Optional[list[int]] = None
|
||||
extra_args: Optional[dict[str, Any]] = None
|
||||
|
||||
# Fields used for bad words
|
||||
bad_words: Optional[list[str]] = None
|
||||
_bad_words_token_ids: list[list[int]] = msgspec.field(default_factory=list)
|
||||
|
||||
@staticmethod
|
||||
def from_optional(
|
||||
n: Optional[int] = 1,
|
||||
@@ -464,6 +469,46 @@ class SamplingParams(
|
||||
eos_ids.update(self.stop_token_ids)
|
||||
self.stop_token_ids = list(eos_ids)
|
||||
|
||||
def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None:
|
||||
if self.bad_words is None:
|
||||
return
|
||||
for bad_word in self.bad_words:
|
||||
# To prohibit words both at the beginning
|
||||
# and in the middle of text
|
||||
# (related to add_prefix_space tokenizer parameter)
|
||||
for add_prefix_space in [False, True]:
|
||||
prefix = " " if add_prefix_space else ""
|
||||
prompt = prefix + bad_word.lstrip()
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
# Mistral tokenizers should not add special tokens
|
||||
prompt_token_ids = tokenizer.encode(text=prompt)
|
||||
else:
|
||||
prompt_token_ids = tokenizer.encode(
|
||||
text=prompt, add_special_tokens=False)
|
||||
|
||||
# If no space at the beginning
|
||||
# or if prefix space produces a new word token
|
||||
if (not add_prefix_space) or (
|
||||
add_prefix_space and prompt_token_ids[0]
|
||||
!= self._bad_words_token_ids[-1][0]
|
||||
and len(prompt_token_ids) == len(
|
||||
self._bad_words_token_ids[-1])):
|
||||
self._bad_words_token_ids.append(prompt_token_ids)
|
||||
|
||||
invalid_token_ids = [
|
||||
token_id for bad_words_token_ids in self._bad_words_token_ids
|
||||
for token_id in bad_words_token_ids
|
||||
if token_id < 0 or token_id > tokenizer.max_token_id
|
||||
]
|
||||
if len(invalid_token_ids) > 0:
|
||||
raise ValueError(
|
||||
f"The model vocabulary size is {tokenizer.max_token_id+1},"
|
||||
f" but the following tokens"
|
||||
f" were specified as bad: {invalid_token_ids}."
|
||||
f" All token id values should be integers satisfying:"
|
||||
f" 0 <= token_id <= {tokenizer.max_token_id}.")
|
||||
|
||||
@cached_property
|
||||
def sampling_type(self) -> SamplingType:
|
||||
if self.temperature < _SAMPLING_EPS:
|
||||
@@ -476,6 +521,11 @@ class SamplingParams(
|
||||
def all_stop_token_ids(self) -> set[int]:
|
||||
return self._all_stop_token_ids
|
||||
|
||||
@property
|
||||
def bad_words_token_ids(self) -> list[list[int]]:
|
||||
# For internal use only. Backward compatibility not guaranteed
|
||||
return self._bad_words_token_ids
|
||||
|
||||
def clone(self) -> "SamplingParams":
|
||||
"""Deep copy, but maybe not the LogitsProcessor objects.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user