[BugFix] Fix bad words with speculative decoding (#31908)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -691,9 +691,13 @@ def test_frequency_penalties(rejection_sampler):
|
||||
|
||||
|
||||
def test_bad_words(rejection_sampler):
|
||||
"""Test rejection sampling with bad words constraints"""
|
||||
"""Test rejection sampling with bad words constraints.
|
||||
|
||||
This test applies bad words to non-consecutive requests (0 and 2, but not 1)
|
||||
to verify correct logit indexing when iterating over requests with bad words.
|
||||
"""
|
||||
spec_tokens = [[1, 2, 3], [1, 15, 3], [1, 2, 3]]
|
||||
output_tokens = [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]
|
||||
output_tokens = [[1, 2, 3, 4], [1, 15, 3, 4], [1, 2, 3, 4]]
|
||||
|
||||
logits = create_logits_tensor(output_tokens, token_idx_to_override=15)
|
||||
metadata = create_sampling_metadata(
|
||||
@@ -701,17 +705,9 @@ def test_bad_words(rejection_sampler):
|
||||
output_token_ids=[[2], [3], [4]],
|
||||
spec_token_ids=spec_tokens,
|
||||
bad_words_token_ids={
|
||||
0: [
|
||||
[
|
||||
2,
|
||||
]
|
||||
],
|
||||
1: [
|
||||
[
|
||||
2,
|
||||
]
|
||||
],
|
||||
# Do not apply bad words to the last request
|
||||
0: [[2]],
|
||||
# Request 1 has no bad words (to test non-consecutive request handling)
|
||||
2: [[2]],
|
||||
},
|
||||
)
|
||||
bonus_token_tensor = torch.tensor(
|
||||
@@ -726,8 +722,11 @@ def test_bad_words(rejection_sampler):
|
||||
sampling_metadata=metadata,
|
||||
)
|
||||
|
||||
# Request 0: bad word [2] matches prefix, so token 2 is rejected -> 15
|
||||
# Request 1: no bad words, all tokens match -> [1, 15, 3, 4]
|
||||
# Request 2: bad word [2] matches prefix, so token 2 is rejected -> 15
|
||||
expected = torch.tensor(
|
||||
[[1, 15, -1, -1], [1, 15, 3, 4], [1, 2, 3, 4]],
|
||||
[[1, 15, -1, -1], [1, 15, 3, 4], [1, 15, -1, -1]],
|
||||
dtype=torch.int,
|
||||
device=logits.device,
|
||||
)
|
||||
|
||||
@@ -42,11 +42,16 @@ def apply_bad_words_with_drafts(
|
||||
num_draft_tokens: list[int],
|
||||
) -> None:
|
||||
start_idx = 0
|
||||
for i, bad_words_ids in bad_words_token_ids.items():
|
||||
for draft_idx in range(num_draft_tokens[i]):
|
||||
_apply_bad_words_single_batch(
|
||||
logits[start_idx + draft_idx],
|
||||
bad_words_ids,
|
||||
past_tokens_ids[start_idx + draft_idx],
|
||||
)
|
||||
start_idx += num_draft_tokens[i]
|
||||
remaining = len(bad_words_token_ids)
|
||||
for i, n in enumerate(num_draft_tokens):
|
||||
if (bad_words_ids := bad_words_token_ids.get(i)) is not None:
|
||||
for draft_idx in range(start_idx, start_idx + n):
|
||||
_apply_bad_words_single_batch(
|
||||
logits[draft_idx],
|
||||
bad_words_ids,
|
||||
past_tokens_ids[draft_idx],
|
||||
)
|
||||
remaining -= 1
|
||||
if not remaining:
|
||||
break
|
||||
start_idx += n
|
||||
|
||||
Reference in New Issue
Block a user