diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py
index 61caffee4..d8ae57984 100644
--- a/tests/v1/sample/test_rejection_sampler.py
+++ b/tests/v1/sample/test_rejection_sampler.py
@@ -691,9 +691,13 @@ def test_frequency_penalties(rejection_sampler):
 
 
 
 def test_bad_words(rejection_sampler):
-    """Test rejection sampling with bad words constraints"""
+    """Test rejection sampling with bad words constraints.
+
+    This test applies bad words to non-consecutive requests (0 and 2, but not 1)
+    to verify correct logit indexing when iterating over requests with bad words.
+    """
     spec_tokens = [[1, 2, 3], [1, 15, 3], [1, 2, 3]]
-    output_tokens = [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]
+    output_tokens = [[1, 2, 3, 4], [1, 15, 3, 4], [1, 2, 3, 4]]
     logits = create_logits_tensor(output_tokens, token_idx_to_override=15)
     metadata = create_sampling_metadata(
@@ -701,17 +705,9 @@ def test_bad_words(rejection_sampler):
         output_token_ids=[[2], [3], [4]],
         spec_token_ids=spec_tokens,
         bad_words_token_ids={
-            0: [
-                [
-                    2,
-                ]
-            ],
-            1: [
-                [
-                    2,
-                ]
-            ],
-            # Do not apply bad words to the last request
+            0: [[2]],
+            # Request 1 has no bad words (to test non-consecutive request handling)
+            2: [[2]],
         },
     )
     bonus_token_tensor = torch.tensor(
@@ -726,8 +722,11 @@ def test_bad_words(rejection_sampler):
         sampling_metadata=metadata,
     )
 
+    # Request 0: bad word [2] matches prefix, so token 2 is rejected -> 15
+    # Request 1: no bad words, all tokens match -> [1, 15, 3, 4]
+    # Request 2: bad word [2] matches prefix, so token 2 is rejected -> 15
     expected = torch.tensor(
-        [[1, 15, -1, -1], [1, 15, 3, 4], [1, 2, 3, 4]],
+        [[1, 15, -1, -1], [1, 15, 3, 4], [1, 15, -1, -1]],
         dtype=torch.int,
         device=logits.device,
     )
diff --git a/vllm/v1/sample/ops/bad_words.py b/vllm/v1/sample/ops/bad_words.py
index 8e2c798dd..56972e517 100644
--- a/vllm/v1/sample/ops/bad_words.py
+++ b/vllm/v1/sample/ops/bad_words.py
@@ -42,11 +42,16 @@
     num_draft_tokens: list[int],
 ) -> None:
     start_idx = 0
-    for i, bad_words_ids in bad_words_token_ids.items():
-        for draft_idx in range(num_draft_tokens[i]):
-            _apply_bad_words_single_batch(
-                logits[start_idx + draft_idx],
-                bad_words_ids,
-                past_tokens_ids[start_idx + draft_idx],
-            )
-        start_idx += num_draft_tokens[i]
+    remaining = len(bad_words_token_ids)
+    for i, n in enumerate(num_draft_tokens):
+        if (bad_words_ids := bad_words_token_ids.get(i)) is not None:
+            for draft_idx in range(start_idx, start_idx + n):
+                _apply_bad_words_single_batch(
+                    logits[draft_idx],
+                    bad_words_ids,
+                    past_tokens_ids[draft_idx],
+                )
+            remaining -= 1
+            if not remaining:
+                break
+        start_idx += n