[V1] Logit processors for rejection sampler (#19482)

Signed-off-by: southfreebird <yvorott@gmail.com>
Signed-off-by: Sergei Skvortsov <sergeyskv@nebius.com>
Signed-off-by: Sergei Skvortsov <yvorott@gmail.com>
Co-authored-by: Sergei Skvortsov <sergeyskv@nebius.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Sergei Skvortsov
2025-10-07 21:02:49 +01:00
committed by GitHub
parent 0c824fc46f
commit 6ebaf43ee4
12 changed files with 471 additions and 92 deletions

View File

@@ -6,6 +6,7 @@ import pytest
import torch
import torch.nn.functional as F
from tests.v1.sample.utils import create_allowed_token_ids
from vllm.platforms import current_platform
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
@@ -21,7 +22,9 @@ def rejection_sampler():
def create_logits_tensor(
output_token_ids: list[list[int]], vocab_size: int = 100
output_token_ids: list[list[int]],
vocab_size: int = 100,
token_idx_to_override: Optional[int] = None,
) -> torch.Tensor:
"""Helper function to create logits tensor that
will produce desired token ids on argmax"""
@@ -33,15 +36,25 @@ def create_logits_tensor(
for j, token_id in enumerate(tokens):
logits[start_loc + j, token_id] = 100.0
start_loc += len(tokens)
if token_idx_to_override:
logits[:, token_idx_to_override] = 99.0
return logits
def create_sampling_metadata(
all_greedy: bool,
output_token_ids: Optional[list[list[int]]] = None,
prompt_token_ids: Optional[torch.Tensor] = None,
spec_token_ids: Optional[torch.Tensor] = None,
temperature: Optional[torch.Tensor] = None,
top_k: Optional[torch.Tensor] = None,
top_p: Optional[torch.Tensor] = None,
generators: Optional[dict[int, Any]] = None,
frequency_penalties: Optional[list[float]] = None,
presence_penalties: Optional[list[float]] = None,
repetition_penalties: Optional[list[float]] = None,
bad_words_token_ids: Optional[dict[int, list[list[int]]]] = None,
allowed_token_ids_mask: Optional[torch.Tensor] = None,
) -> SamplingMetadata:
"""Create a v1 sampling metadata object with all_greedy set
to the given value. Either all greedy or all random sampling
@@ -53,6 +66,21 @@ def create_sampling_metadata(
else:
assert temperature is not None
if any([frequency_penalties, presence_penalties, repetition_penalties]):
no_penalties = False
assert output_token_ids
assert len(output_token_ids) > 0
frequency_penalties = torch.tensor(frequency_penalties, device=DEVICE)
presence_penalties = torch.tensor(presence_penalties, device=DEVICE)
repetition_penalties = torch.tensor(repetition_penalties, device=DEVICE)
else:
no_penalties = True
frequency_penalties = torch.tensor([])
presence_penalties = torch.tensor([])
repetition_penalties = torch.tensor([])
return SamplingMetadata(
temperature=temperature,
all_greedy=all_greedy,
@@ -61,14 +89,15 @@ def create_sampling_metadata(
top_k=top_k,
generators=generators,
max_num_logprobs=0,
no_penalties=False,
prompt_token_ids=None,
frequency_penalties=torch.tensor([]),
presence_penalties=torch.tensor([]),
repetition_penalties=torch.tensor([]),
output_token_ids=[],
allowed_token_ids_mask=None,
bad_words_token_ids={},
no_penalties=no_penalties,
prompt_token_ids=prompt_token_ids,
frequency_penalties=frequency_penalties,
presence_penalties=presence_penalties,
repetition_penalties=repetition_penalties,
output_token_ids=[] if output_token_ids is None else output_token_ids,
spec_token_ids=[] if spec_token_ids is None else spec_token_ids,
allowed_token_ids_mask=allowed_token_ids_mask,
bad_words_token_ids={} if bad_words_token_ids is None else bad_words_token_ids,
logitsprocs=LogitsProcessors(),
)
@@ -611,3 +640,136 @@ def test_top_p(rejection_sampler, top_p):
unmasked_indices=top_p_indices,
sampling_metadata=sampling_metadata,
)
########################### Tests for Logit Processors ###################
def test_frequency_penalties(rejection_sampler):
"""Test rejection sampling with frequency penalties"""
spec_tokens = [[1, 1, 1], [], [1, 1, 1]]
output_tokens = [[1, 1, 1, 1], [7], [1, 1, 1, 1]] # 1, 7 and 1 are the bonus tokens
num_requsts = len(spec_tokens)
logits = create_logits_tensor(output_tokens, token_idx_to_override=15)
metadata = create_sampling_metadata(
all_greedy=True,
output_token_ids=[[2], [3], [4]],
spec_token_ids=spec_tokens,
prompt_token_ids=torch.tensor([[5, 6, 7], [6, 7, 8], [7, 8, 9]], device=DEVICE),
frequency_penalties=[1.5, 1.5, 0.7],
presence_penalties=[0.0] * num_requsts,
repetition_penalties=[1.0] * num_requsts,
)
bonus_token_tensor = torch.tensor(
[output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device
)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected = torch.tensor(
[[1, 15, -1, -1], [7, -1, -1, -1], [1, 1, 15, -1]],
dtype=torch.int,
device=logits.device,
)
assert torch.equal(output, expected)
def test_bad_words(rejection_sampler):
"""Test rejection sampling with bad words constraints"""
spec_tokens = [[1, 2, 3], [1, 15, 3], [1, 2, 3]]
output_tokens = [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]
logits = create_logits_tensor(output_tokens, token_idx_to_override=15)
metadata = create_sampling_metadata(
all_greedy=True,
output_token_ids=[[2], [3], [4]],
spec_token_ids=spec_tokens,
bad_words_token_ids={
0: [
[
2,
]
],
1: [
[
2,
]
],
# Do not apply bad words to the last request
},
)
bonus_token_tensor = torch.tensor(
[output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device
)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected = torch.tensor(
[[1, 15, -1, -1], [1, 15, 3, 4], [1, 2, 3, 4]],
dtype=torch.int,
device=logits.device,
)
assert torch.equal(output, expected)
def test_allowed_token_ids(rejection_sampler):
"""Test rejection sampling with allowed token ids"""
spec_tokens = [[1, 2, 10], [10, 5, 3], [7, 10, 12]]
output_tokens = [[1, 2, 10, 5], [10, 5, 10, 5], [7, 10, 12, 5]]
# Not allowed tokens:
# 0: 0-4
# 1: 1-5
# 2: 2-6
num_allowed_token_ids = 5
# Use the token 15 as the sampler choose if a token rejected
logits = create_logits_tensor(output_tokens, token_idx_to_override=15)
batch_size = len(output_tokens)
_, vocab_size = logits.size()
mask = create_allowed_token_ids(
batch_size=batch_size,
vocab_size=vocab_size,
num_allowed_token_ids=num_allowed_token_ids,
device=logits.device,
)
metadata = create_sampling_metadata(
all_greedy=True,
output_token_ids=[[], [], []],
spec_token_ids=spec_tokens,
allowed_token_ids_mask=mask,
)
bonus_token_tensor = torch.tensor(
[output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device
)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected = torch.tensor(
[[15, -1, -1, -1], [10, 5, 10, -1], [7, 10, 12, 5]],
dtype=torch.int,
device=logits.device,
)
assert torch.equal(output, expected)

View File

@@ -1,12 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import numpy as np
import pytest
import torch
from tests.v1.sample.utils import create_allowed_token_ids
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.sample.logits_processor import LogitsProcessors
@@ -51,26 +50,6 @@ def _create_prompt_tokens_tensor(
)
def _create_allowed_token_ids(
batch_size: int,
vocab_size: int,
num_allowed_token_ids: int,
device: torch.device,
) -> Optional[torch.Tensor]:
mask: Optional[torch.Tensor] = None
for i in range(batch_size):
if i % 2 == 1:
continue
if mask is None:
mask = torch.zeros(
(batch_size, vocab_size), dtype=torch.bool, device=device
)
start = min(i, vocab_size - 1)
end = min(i + num_allowed_token_ids, vocab_size - 1)
mask[i, start:end] = True
return mask
def _create_bad_words_token_ids(
batch_size: int,
vocab_size: int,
@@ -173,6 +152,7 @@ def _create_default_sampling_metadata(
prompt_token_ids, vocab_size, device
),
output_token_ids=output_token_ids,
spec_token_ids=[[] for _ in range(batch_size)],
frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
@@ -241,7 +221,9 @@ def test_sampler_presence_penalty(
)
sampling_metadata.no_penalties = False
sampler = Sampler()
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
logits = sampler.apply_penalties(
fake_logits, sampling_metadata, sampling_metadata.output_token_ids
)
logits = logits.cpu()
for batch_idx in range(batch_size):
# Since all tokens initially have the same logits, the non-penalized
@@ -293,7 +275,9 @@ def test_sampler_frequency_penalty(
sampling_metadata.output_token_ids = output_token_ids
sampling_metadata.no_penalties = False
sampler = Sampler()
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
logits = sampler.apply_penalties(
fake_logits, sampling_metadata, sampling_metadata.output_token_ids
)
logits = logits.cpu()
for batch_idx in range(batch_size):
non_penalized_token_id = logits[batch_idx].argmax().item()
@@ -343,7 +327,9 @@ def test_sampler_repetition_penalty(
)
sampling_metadata.no_penalties = False
sampler = Sampler()
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
logits = sampler.apply_penalties(
fake_logits, sampling_metadata, sampling_metadata.output_token_ids
)
logits = logits.cpu()
for batch_idx in range(batch_size):
non_penalized_token_id = logits[batch_idx].argmax().item()
@@ -394,7 +380,7 @@ def test_sampler_allowed_token_ids(
sampling_metadata = _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)
)
mask = _create_allowed_token_ids(
mask = create_allowed_token_ids(
batch_size=batch_size,
vocab_size=VOCAB_SIZE,
num_allowed_token_ids=num_allowed_token_ids,
@@ -402,7 +388,9 @@ def test_sampler_allowed_token_ids(
)
sampling_metadata.allowed_token_ids_mask = mask
sampler = Sampler()
logits = sampler.apply_allowed_token_ids(fake_logits, sampling_metadata)
logits = sampler.apply_logits_processors(
fake_logits, sampling_metadata, predict_bonus_token=False
)
logits = logits.cpu()
for batch_idx in range(batch_size):
logits_for_req = logits[batch_idx]
@@ -444,7 +432,9 @@ def test_sampler_bad_words(
sampling_metadata, VOCAB_SIZE
)
sampler = Sampler()
logits = sampler.apply_bad_words(fake_logits, sampling_metadata)
logits = sampler.apply_logits_processors(
fake_logits, sampling_metadata, predict_bonus_token=False
)
logits = logits.cpu()
for batch_idx in range(batch_size):
logits_for_req = logits[batch_idx]

View File

@@ -215,3 +215,23 @@ def fake_apply_logitsprocs(
for processor in test_fakes.get_logitsprocs():
logits = processor.apply(logits)
return logits
def create_allowed_token_ids(
batch_size: int,
vocab_size: int,
num_allowed_token_ids: int,
device: torch.device,
) -> Optional[torch.Tensor]:
mask: Optional[torch.Tensor] = None
for i in range(batch_size):
if i % 2 == 1:
continue
if mask is None:
mask = torch.zeros(
(batch_size, vocab_size), dtype=torch.bool, device=device
)
start = min(i, vocab_size - 1)
end = min(i + num_allowed_token_ids, vocab_size - 1)
mask[i, start:end] = True
return mask