Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -19,8 +19,8 @@ to sample from."""
|
||||
|
||||
|
||||
def get_bad_words_logits_processors(
|
||||
bad_words: list[str],
|
||||
tokenizer: AnyTokenizer) -> list[LogitsProcessor]:
|
||||
bad_words: list[str], tokenizer: AnyTokenizer
|
||||
) -> list[LogitsProcessor]:
|
||||
bad_words_ids: list[list[int]] = list()
|
||||
|
||||
for bad_word in bad_words:
|
||||
@@ -31,15 +31,15 @@ def get_bad_words_logits_processors(
|
||||
prefix = " " if add_prefix_space else ""
|
||||
prompt = prefix + bad_word.lstrip()
|
||||
|
||||
prompt_token_ids = tokenizer.encode(text=prompt,
|
||||
add_special_tokens=False)
|
||||
prompt_token_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
|
||||
|
||||
# If no space at the beginning
|
||||
# or if prefix space produces a new word token
|
||||
if (not add_prefix_space) or (
|
||||
add_prefix_space
|
||||
and prompt_token_ids[0] != bad_words_ids[-1][0]
|
||||
and len(prompt_token_ids) == len(bad_words_ids[-1])):
|
||||
add_prefix_space
|
||||
and prompt_token_ids[0] != bad_words_ids[-1][0]
|
||||
and len(prompt_token_ids) == len(bad_words_ids[-1])
|
||||
):
|
||||
bad_words_ids.append(prompt_token_ids)
|
||||
|
||||
return [NoBadWordsLogitsProcessor(bad_words_ids=bad_words_ids)]
|
||||
@@ -78,8 +78,9 @@ class NoBadWordsLogitsProcessor:
|
||||
assert len(actual_prefix) == len(expected_prefix)
|
||||
|
||||
is_match = tuple(actual_prefix) == tuple(expected_prefix)
|
||||
last_token_bias[last_token_id] += (self._SMALLEST_LOGIT if is_match
|
||||
else self._NEUTRAL_LOGIT)
|
||||
last_token_bias[last_token_id] += (
|
||||
self._SMALLEST_LOGIT if is_match else self._NEUTRAL_LOGIT
|
||||
)
|
||||
|
||||
logits = logits + self.word_bias + last_token_bias
|
||||
|
||||
@@ -93,9 +94,9 @@ class NoBadWordsLogitsProcessor:
|
||||
|
||||
self._check_token_ids_bounds(vocab_size=vocab_size)
|
||||
|
||||
self.word_bias = torch.zeros((vocab_size, ),
|
||||
dtype=torch.float,
|
||||
device=logits.device)
|
||||
self.word_bias = torch.zeros(
|
||||
(vocab_size,), dtype=torch.float, device=logits.device
|
||||
)
|
||||
|
||||
for bad_word_ids in self.bad_words_ids:
|
||||
if len(bad_word_ids) == 1:
|
||||
@@ -116,4 +117,5 @@ class NoBadWordsLogitsProcessor:
|
||||
f" but the following tokens"
|
||||
f" were specified as bad: {invalid_token_ids}."
|
||||
f" All token id values should be integers satisfying:"
|
||||
f" 0 <= token_id < {vocab_size}.")
|
||||
f" 0 <= token_id < {vocab_size}."
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user