[V1] Support bad_words in sampler (#13376)

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Authored by 22quinn on 2025-03-08 14:50:26 -08:00, committed by GitHub
Parent: 9513290032
Commit: eb8b5eb183
13 changed files with 266 additions and 28 deletions
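The new `bad_words` field bans exact token sequences at sampling time. A minimal usage sketch (the model name, prompt, and banned words here are illustrative, not part of this commit):

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")
    params = SamplingParams(
        temperature=0.8,
        max_tokens=32,
        # Each entry is banned both at the start of text and mid-text;
        # see update_from_tokenizer below for how both variants are tokenized.
        bad_words=["moist", "synergy"],
    )
    outputs = llm.generate(["Describe the cake:"], params)
    print(outputs[0].outputs[0].text)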

vllm/sampling_params.py

@@ -11,6 +11,8 @@ from pydantic import BaseModel
 from vllm.logger import init_logger
 from vllm.logits_process import LogitsProcessor
 from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 
 logger = init_logger(__name__)
@@ -202,7 +204,6 @@ class SamplingParams(
     seed: Optional[int] = None
     stop: Optional[Union[str, list[str]]] = None
     stop_token_ids: Optional[list[int]] = None
-    bad_words: Optional[list[str]] = None
     ignore_eos: bool = False
     max_tokens: Optional[int] = 16
     min_tokens: int = 0
@@ -232,6 +233,10 @@ class SamplingParams(
     allowed_token_ids: Optional[list[int]] = None
     extra_args: Optional[dict[str, Any]] = None
 
+    # Fields used for bad words
+    bad_words: Optional[list[str]] = None
+    _bad_words_token_ids: list[list[int]] = msgspec.field(default_factory=list)
+
     @staticmethod
     def from_optional(
         n: Optional[int] = 1,
@@ -464,6 +469,46 @@ class SamplingParams(
             eos_ids.update(self.stop_token_ids)
             self.stop_token_ids = list(eos_ids)
 
+    def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None:
+        if self.bad_words is None:
+            return
+
+        for bad_word in self.bad_words:
+            # To prohibit words both at the beginning
+            # and in the middle of text
+            # (related to add_prefix_space tokenizer parameter)
+            for add_prefix_space in [False, True]:
+                prefix = " " if add_prefix_space else ""
+                prompt = prefix + bad_word.lstrip()
+
+                if isinstance(tokenizer, MistralTokenizer):
+                    # Mistral tokenizers should not add special tokens
+                    prompt_token_ids = tokenizer.encode(text=prompt)
+                else:
+                    prompt_token_ids = tokenizer.encode(
+                        text=prompt, add_special_tokens=False)
+
+                # If no space at the beginning
+                # or if prefix space produces a new word token
+                if (not add_prefix_space) or (
+                        add_prefix_space and prompt_token_ids[0]
+                        != self._bad_words_token_ids[-1][0]
+                        and len(prompt_token_ids) == len(
+                            self._bad_words_token_ids[-1])):
+                    self._bad_words_token_ids.append(prompt_token_ids)
+
+        invalid_token_ids = [
+            token_id for bad_words_token_ids in self._bad_words_token_ids
+            for token_id in bad_words_token_ids
+            if token_id < 0 or token_id > tokenizer.max_token_id
+        ]
+        if len(invalid_token_ids) > 0:
+            raise ValueError(
+                f"The model vocabulary size is {tokenizer.max_token_id+1},"
+                f" but the following tokens"
+                f" were specified as bad: {invalid_token_ids}."
+                f" All token id values should be integers satisfying:"
+                f" 0 <= token_id <= {tokenizer.max_token_id}.")
+
     @cached_property
     def sampling_type(self) -> SamplingType:
         if self.temperature < _SAMPLING_EPS:
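Why the loop over `add_prefix_space` above: BPE-style tokenizers typically map the same word to different token ids at the start of text versus after a space, so both encodings must be banned. A standalone sketch (assumes the Hugging Face `transformers` package and the `gpt2` tokenizer, neither of which is part of this diff):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")

    # The same word produces different token ids with and without a
    # leading space; banning only one variant would miss the other.
    print(tok.encode("hello", add_special_tokens=False))   # e.g. [31373]
    print(tok.encode(" hello", add_special_tokens=False))  # e.g. [23748]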
@@ -476,6 +521,11 @@ class SamplingParams(
     def all_stop_token_ids(self) -> set[int]:
         return self._all_stop_token_ids
 
+    @property
+    def bad_words_token_ids(self) -> list[list[int]]:
+        # For internal use only. Backward compatibility not guaranteed
+        return self._bad_words_token_ids
+
     def clone(self) -> "SamplingParams":
         """Deep copy, but maybe not the LogitsProcessor objects.