[V1] Support bad_words in sampler (#13376)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -77,6 +77,49 @@ def _create_allowed_token_ids(
|
||||
return mask
|
||||
|
||||
|
||||
def _create_bad_words_token_ids(
|
||||
batch_size: int, vocab_size: int,
|
||||
bad_words_lengths: list[tuple[int]]) -> dict[int, list[list[int]]]:
|
||||
bad_words_token_ids = {}
|
||||
for batch_idx in range(batch_size):
|
||||
token_ids_single_batch = []
|
||||
for bad_words_length in bad_words_lengths:
|
||||
token_ids = np.random.choice(vocab_size,
|
||||
size=bad_words_length,
|
||||
replace=True).tolist()
|
||||
token_ids_single_batch.append(token_ids)
|
||||
bad_words_token_ids[batch_idx] = token_ids_single_batch
|
||||
if batch_size >= 2:
|
||||
# Test no bad_words for some batch
|
||||
no_bad_words_batch_idx = np.random.choice(batch_size)
|
||||
bad_words_token_ids.pop(no_bad_words_batch_idx, None)
|
||||
return bad_words_token_ids
|
||||
|
||||
|
||||
def _update_output_token_ids_for_bad_words(
|
||||
metadata: SamplingMetadata, vocab_size: int) -> dict[int, list[int]]:
|
||||
bad_words_last_tokens = {}
|
||||
for batch_idx, bad_words_token_ids in metadata.bad_words_token_ids.items():
|
||||
output_token_ids = metadata.output_token_ids[batch_idx]
|
||||
bad_words_last_token: list[int] = []
|
||||
for i, bad_word_token_ids in enumerate(bad_words_token_ids):
|
||||
if len(bad_word_token_ids) == 1:
|
||||
# Single token id always affects logits
|
||||
bad_words_last_token.append(bad_word_token_ids[0])
|
||||
else:
|
||||
prefix_length = len(bad_word_token_ids) - 1
|
||||
has_bad_words = np.random.choice([True, False])
|
||||
if has_bad_words:
|
||||
output_token_ids[-prefix_length:] = bad_word_token_ids[:-1]
|
||||
bad_words_last_token.append(bad_word_token_ids[-1])
|
||||
break # Maximum one update to output_token_ids
|
||||
else: # Make sure no accidental match to bad words
|
||||
output_token_ids[-1] = (bad_word_token_ids[-2] +
|
||||
1) % vocab_size
|
||||
bad_words_last_tokens[batch_idx] = bad_words_last_token
|
||||
return bad_words_last_tokens
|
||||
|
||||
|
||||
def _create_default_sampling_metadata(
|
||||
num_output_tokens: int,
|
||||
batch_size: int,
|
||||
@@ -112,6 +155,7 @@ def _create_default_sampling_metadata(
|
||||
min_tokens={},
|
||||
logit_bias=[None] * batch_size,
|
||||
allowed_token_ids_mask=None,
|
||||
bad_words_token_ids={},
|
||||
)
|
||||
return fake_sampling_metadata
|
||||
|
||||
@@ -467,3 +511,35 @@ def test_sampler_allowed_token_ids(device: str, batch_size: int,
|
||||
"inf"), f"{batch_idx}, {token_id}"
|
||||
else:
|
||||
assert logits_for_req[token_id] != -float("inf")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 32])
|
||||
@pytest.mark.parametrize("bad_words_lengths", [(1, ), (1, 3), (2, 2)])
|
||||
def test_sampler_bad_words(device: str, batch_size: int,
|
||||
bad_words_lengths: list[tuple[int]]):
|
||||
"""
|
||||
Test to verify that when the bad words restriction is present, tokens
|
||||
are penalized based on their match with the bad words.
|
||||
"""
|
||||
torch.set_default_device(device)
|
||||
# Create fake logits where each token is assigned the same
|
||||
# logit value.
|
||||
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
|
||||
sampling_metadata = _create_default_sampling_metadata(
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
|
||||
sampling_metadata.bad_words_token_ids = _create_bad_words_token_ids(
|
||||
batch_size, VOCAB_SIZE, bad_words_lengths)
|
||||
bad_words_last_tokens = _update_output_token_ids_for_bad_words(
|
||||
sampling_metadata, VOCAB_SIZE)
|
||||
sampler = Sampler()
|
||||
logits = sampler.apply_bad_words(fake_logits, sampling_metadata)
|
||||
logits = logits.cpu()
|
||||
for batch_idx in range(batch_size):
|
||||
logits_for_req = logits[batch_idx]
|
||||
for token_id in range(VOCAB_SIZE):
|
||||
if (batch_idx in bad_words_last_tokens
|
||||
and token_id in bad_words_last_tokens[batch_idx]):
|
||||
assert logits_for_req[token_id] == -float("inf")
|
||||
else:
|
||||
assert logits_for_req[token_id] != -float("inf")
|
||||
|
||||
Reference in New Issue
Block a user