[V1] Support bad_words in sampler (#13376)

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
22quinn
2025-03-08 14:50:26 -08:00
committed by GitHub
parent 9513290032
commit eb8b5eb183
13 changed files with 266 additions and 28 deletions

View File

@@ -77,6 +77,49 @@ def _create_allowed_token_ids(
return mask
def _create_bad_words_token_ids(
batch_size: int, vocab_size: int,
bad_words_lengths: list[tuple[int]]) -> dict[int, list[list[int]]]:
bad_words_token_ids = {}
for batch_idx in range(batch_size):
token_ids_single_batch = []
for bad_words_length in bad_words_lengths:
token_ids = np.random.choice(vocab_size,
size=bad_words_length,
replace=True).tolist()
token_ids_single_batch.append(token_ids)
bad_words_token_ids[batch_idx] = token_ids_single_batch
if batch_size >= 2:
# Test no bad_words for some batch
no_bad_words_batch_idx = np.random.choice(batch_size)
bad_words_token_ids.pop(no_bad_words_batch_idx, None)
return bad_words_token_ids
def _update_output_token_ids_for_bad_words(
metadata: SamplingMetadata, vocab_size: int) -> dict[int, list[int]]:
bad_words_last_tokens = {}
for batch_idx, bad_words_token_ids in metadata.bad_words_token_ids.items():
output_token_ids = metadata.output_token_ids[batch_idx]
bad_words_last_token: list[int] = []
for i, bad_word_token_ids in enumerate(bad_words_token_ids):
if len(bad_word_token_ids) == 1:
# Single token id always affects logits
bad_words_last_token.append(bad_word_token_ids[0])
else:
prefix_length = len(bad_word_token_ids) - 1
has_bad_words = np.random.choice([True, False])
if has_bad_words:
output_token_ids[-prefix_length:] = bad_word_token_ids[:-1]
bad_words_last_token.append(bad_word_token_ids[-1])
break # Maximum one update to output_token_ids
else: # Make sure no accidental match to bad words
output_token_ids[-1] = (bad_word_token_ids[-2] +
1) % vocab_size
bad_words_last_tokens[batch_idx] = bad_words_last_token
return bad_words_last_tokens
def _create_default_sampling_metadata(
num_output_tokens: int,
batch_size: int,
@@ -112,6 +155,7 @@ def _create_default_sampling_metadata(
min_tokens={},
logit_bias=[None] * batch_size,
allowed_token_ids_mask=None,
bad_words_token_ids={},
)
return fake_sampling_metadata
@@ -467,3 +511,35 @@ def test_sampler_allowed_token_ids(device: str, batch_size: int,
"inf"), f"{batch_idx}, {token_id}"
else:
assert logits_for_req[token_id] != -float("inf")
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("bad_words_lengths", [(1, ), (1, 3), (2, 2)])
def test_sampler_bad_words(device: str, batch_size: int,
bad_words_lengths: list[tuple[int]]):
"""
Test to verify that when the bad words restriction is present, tokens
are penalized based on their match with the bad words.
"""
torch.set_default_device(device)
# Create fake logits where each token is assigned the same
# logit value.
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
sampling_metadata = _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
sampling_metadata.bad_words_token_ids = _create_bad_words_token_ids(
batch_size, VOCAB_SIZE, bad_words_lengths)
bad_words_last_tokens = _update_output_token_ids_for_bad_words(
sampling_metadata, VOCAB_SIZE)
sampler = Sampler()
logits = sampler.apply_bad_words(fake_logits, sampling_metadata)
logits = logits.cpu()
for batch_idx in range(batch_size):
logits_for_req = logits[batch_idx]
for token_id in range(VOCAB_SIZE):
if (batch_idx in bad_words_last_tokens
and token_id in bad_words_last_tokens[batch_idx]):
assert logits_for_req[token_id] == -float("inf")
else:
assert logits_for_req[token_id] != -float("inf")