[BugFix] Fix bad words with speculative decoding (#31908)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Nick Hill
2026-01-07 12:46:42 -08:00
committed by GitHub
parent 6170d47d22
commit 10ef65eded
2 changed files with 26 additions and 22 deletions

View File

@@ -691,9 +691,13 @@ def test_frequency_penalties(rejection_sampler):
def test_bad_words(rejection_sampler):
"""Test rejection sampling with bad words constraints"""
"""Test rejection sampling with bad words constraints.
This test applies bad words to non-consecutive requests (0 and 2, but not 1)
to verify correct logit indexing when iterating over requests with bad words.
"""
spec_tokens = [[1, 2, 3], [1, 15, 3], [1, 2, 3]]
output_tokens = [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]
output_tokens = [[1, 2, 3, 4], [1, 15, 3, 4], [1, 2, 3, 4]]
logits = create_logits_tensor(output_tokens, token_idx_to_override=15)
metadata = create_sampling_metadata(
@@ -701,17 +705,9 @@ def test_bad_words(rejection_sampler):
output_token_ids=[[2], [3], [4]],
spec_token_ids=spec_tokens,
bad_words_token_ids={
0: [
[
2,
]
],
1: [
[
2,
]
],
# Do not apply bad words to the last request
0: [[2]],
# Request 1 has no bad words (to test non-consecutive request handling)
2: [[2]],
},
)
bonus_token_tensor = torch.tensor(
@@ -726,8 +722,11 @@ def test_bad_words(rejection_sampler):
sampling_metadata=metadata,
)
# Request 0: bad word [2] matches prefix, so token 2 is rejected -> 15
# Request 1: no bad words, all tokens match -> [1, 15, 3, 4]
# Request 2: bad word [2] matches prefix, so token 2 is rejected -> 15
expected = torch.tensor(
[[1, 15, -1, -1], [1, 15, 3, 4], [1, 2, 3, 4]],
[[1, 15, -1, -1], [1, 15, 3, 4], [1, 15, -1, -1]],
dtype=torch.int,
device=logits.device,
)