[V1] Logit processors for rejection sampler (#19482)
Signed-off-by: southfreebird <yvorott@gmail.com> Signed-off-by: Sergei Skvortsov <sergeyskv@nebius.com> Signed-off-by: Sergei Skvortsov <yvorott@gmail.com> Co-authored-by: Sergei Skvortsov <sergeyskv@nebius.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -6,6 +6,7 @@ import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from tests.v1.sample.utils import create_allowed_token_ids
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessors
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
@@ -21,7 +22,9 @@ def rejection_sampler():
|
||||
|
||||
|
||||
def create_logits_tensor(
|
||||
output_token_ids: list[list[int]], vocab_size: int = 100
|
||||
output_token_ids: list[list[int]],
|
||||
vocab_size: int = 100,
|
||||
token_idx_to_override: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Helper function to create logits tensor that
|
||||
will produce desired token ids on argmax"""
|
||||
@@ -33,15 +36,25 @@ def create_logits_tensor(
|
||||
for j, token_id in enumerate(tokens):
|
||||
logits[start_loc + j, token_id] = 100.0
|
||||
start_loc += len(tokens)
|
||||
if token_idx_to_override:
|
||||
logits[:, token_idx_to_override] = 99.0
|
||||
return logits
|
||||
|
||||
|
||||
def create_sampling_metadata(
|
||||
all_greedy: bool,
|
||||
output_token_ids: Optional[list[list[int]]] = None,
|
||||
prompt_token_ids: Optional[torch.Tensor] = None,
|
||||
spec_token_ids: Optional[torch.Tensor] = None,
|
||||
temperature: Optional[torch.Tensor] = None,
|
||||
top_k: Optional[torch.Tensor] = None,
|
||||
top_p: Optional[torch.Tensor] = None,
|
||||
generators: Optional[dict[int, Any]] = None,
|
||||
frequency_penalties: Optional[list[float]] = None,
|
||||
presence_penalties: Optional[list[float]] = None,
|
||||
repetition_penalties: Optional[list[float]] = None,
|
||||
bad_words_token_ids: Optional[dict[int, list[list[int]]]] = None,
|
||||
allowed_token_ids_mask: Optional[torch.Tensor] = None,
|
||||
) -> SamplingMetadata:
|
||||
"""Create a v1 sampling metadata object with all_greedy set
|
||||
to the given value. Either all greedy or all random sampling
|
||||
@@ -53,6 +66,21 @@ def create_sampling_metadata(
|
||||
else:
|
||||
assert temperature is not None
|
||||
|
||||
if any([frequency_penalties, presence_penalties, repetition_penalties]):
|
||||
no_penalties = False
|
||||
|
||||
assert output_token_ids
|
||||
assert len(output_token_ids) > 0
|
||||
|
||||
frequency_penalties = torch.tensor(frequency_penalties, device=DEVICE)
|
||||
presence_penalties = torch.tensor(presence_penalties, device=DEVICE)
|
||||
repetition_penalties = torch.tensor(repetition_penalties, device=DEVICE)
|
||||
else:
|
||||
no_penalties = True
|
||||
frequency_penalties = torch.tensor([])
|
||||
presence_penalties = torch.tensor([])
|
||||
repetition_penalties = torch.tensor([])
|
||||
|
||||
return SamplingMetadata(
|
||||
temperature=temperature,
|
||||
all_greedy=all_greedy,
|
||||
@@ -61,14 +89,15 @@ def create_sampling_metadata(
|
||||
top_k=top_k,
|
||||
generators=generators,
|
||||
max_num_logprobs=0,
|
||||
no_penalties=False,
|
||||
prompt_token_ids=None,
|
||||
frequency_penalties=torch.tensor([]),
|
||||
presence_penalties=torch.tensor([]),
|
||||
repetition_penalties=torch.tensor([]),
|
||||
output_token_ids=[],
|
||||
allowed_token_ids_mask=None,
|
||||
bad_words_token_ids={},
|
||||
no_penalties=no_penalties,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
frequency_penalties=frequency_penalties,
|
||||
presence_penalties=presence_penalties,
|
||||
repetition_penalties=repetition_penalties,
|
||||
output_token_ids=[] if output_token_ids is None else output_token_ids,
|
||||
spec_token_ids=[] if spec_token_ids is None else spec_token_ids,
|
||||
allowed_token_ids_mask=allowed_token_ids_mask,
|
||||
bad_words_token_ids={} if bad_words_token_ids is None else bad_words_token_ids,
|
||||
logitsprocs=LogitsProcessors(),
|
||||
)
|
||||
|
||||
@@ -611,3 +640,136 @@ def test_top_p(rejection_sampler, top_p):
|
||||
unmasked_indices=top_p_indices,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
|
||||
|
||||
########################### Tests for Logit Processors ###################
|
||||
def test_frequency_penalties(rejection_sampler):
    """Check greedy rejection sampling when frequency penalties are set.

    Three requests are built; the middle one has no draft tokens at all.
    The sampler output is compared against a fixed expected tensor in
    which -1 fills the slots after the last emitted token of each row.
    """
    draft_tokens = [[1, 1, 1], [], [1, 1, 1]]
    # The final entry of every row (1, 7 and 1) is the bonus token.
    target_tokens = [[1, 1, 1, 1], [7], [1, 1, 1, 1]]
    request_count = len(draft_tokens)

    # Token 15 gets an overridden logit in every row of the target logits.
    target_logits = create_logits_tensor(target_tokens, token_idx_to_override=15)
    device = target_logits.device

    prompt_ids = torch.tensor([[5, 6, 7], [6, 7, 8], [7, 8, 9]], device=DEVICE)
    sampling_metadata = create_sampling_metadata(
        all_greedy=True,
        output_token_ids=[[2], [3], [4]],
        spec_token_ids=draft_tokens,
        prompt_token_ids=prompt_ids,
        frequency_penalties=[1.5, 1.5, 0.7],
        presence_penalties=[0.0] * request_count,
        repetition_penalties=[1.0] * request_count,
    )

    bonus_ids = torch.tensor([row[-1] for row in target_tokens], device=device)
    spec_metadata = SpecDecodeMetadata.make_dummy(draft_tokens, device=device)

    actual = rejection_sampler(
        spec_metadata,
        draft_probs=None,
        target_logits=target_logits,
        bonus_token_ids=bonus_ids,
        sampling_metadata=sampling_metadata,
    )

    wanted = torch.tensor(
        [
            [1, 15, -1, -1],
            [7, -1, -1, -1],
            [1, 1, 15, -1],
        ],
        dtype=torch.int,
        device=device,
    )
    assert torch.equal(actual, wanted)
|
||||
|
||||
|
||||
def test_bad_words(rejection_sampler):
    """Check greedy rejection sampling when bad-words constraints are set.

    Requests 0 and 1 each list the single-token sequence [2] as a bad
    word; request 2 has no entry in the bad-words mapping. The sampler
    output is compared against a fixed expected tensor, with -1 filling
    the slots after the last emitted token of a row.
    """
    draft_tokens = [[1, 2, 3], [1, 15, 3], [1, 2, 3]]
    target_tokens = [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]

    # Token 15 gets an overridden logit in every row of the target logits.
    target_logits = create_logits_tensor(target_tokens, token_idx_to_override=15)
    device = target_logits.device

    # Keyed by request index; the last request is deliberately left out
    # so that no bad words apply to it.
    banned_sequences = {
        0: [[2]],
        1: [[2]],
    }
    sampling_metadata = create_sampling_metadata(
        all_greedy=True,
        output_token_ids=[[2], [3], [4]],
        spec_token_ids=draft_tokens,
        bad_words_token_ids=banned_sequences,
    )

    bonus_ids = torch.tensor([row[-1] for row in target_tokens], device=device)
    spec_metadata = SpecDecodeMetadata.make_dummy(draft_tokens, device=device)

    actual = rejection_sampler(
        spec_metadata,
        draft_probs=None,
        target_logits=target_logits,
        bonus_token_ids=bonus_ids,
        sampling_metadata=sampling_metadata,
    )

    wanted = torch.tensor(
        [
            [1, 15, -1, -1],
            [1, 15, 3, 4],
            [1, 2, 3, 4],
        ],
        dtype=torch.int,
        device=device,
    )
    assert torch.equal(actual, wanted)
|
||||
|
||||
|
||||
def test_allowed_token_ids(rejection_sampler):
    """Check greedy rejection sampling with an allowed-token-ids mask.

    The mask comes from the ``create_allowed_token_ids`` test helper and
    is handed to the sampler through the sampling metadata. The sampler
    output is compared against a fixed expected tensor, with -1 filling
    the slots after the last emitted token of a row.
    """
    draft_tokens = [[1, 2, 10], [10, 5, 3], [7, 10, 12]]
    target_tokens = [[1, 2, 10, 5], [10, 5, 10, 5], [7, 10, 12, 5]]
    # Not allowed tokens:
    # 0: 0-4
    # 1: 1-5
    # 2: 2-6
    # NOTE(review): the ranges above are carried over from the original
    # comment; confirm them against create_allowed_token_ids.
    allowed_span = 5

    # Token 15 is what the sampler is expected to choose when a draft
    # token is rejected.
    target_logits = create_logits_tensor(target_tokens, token_idx_to_override=15)
    device = target_logits.device

    _, vocab_size = target_logits.shape
    allowed_mask = create_allowed_token_ids(
        batch_size=len(target_tokens),
        vocab_size=vocab_size,
        num_allowed_token_ids=allowed_span,
        device=device,
    )
    sampling_metadata = create_sampling_metadata(
        all_greedy=True,
        output_token_ids=[[], [], []],
        spec_token_ids=draft_tokens,
        allowed_token_ids_mask=allowed_mask,
    )

    bonus_ids = torch.tensor([row[-1] for row in target_tokens], device=device)
    spec_metadata = SpecDecodeMetadata.make_dummy(draft_tokens, device=device)

    actual = rejection_sampler(
        spec_metadata,
        draft_probs=None,
        target_logits=target_logits,
        bonus_token_ids=bonus_ids,
        sampling_metadata=sampling_metadata,
    )

    wanted = torch.tensor(
        [
            [15, -1, -1, -1],
            [10, 5, 10, -1],
            [7, 10, 12, 5],
        ],
        dtype=torch.int,
        device=device,
    )
    assert torch.equal(actual, wanted)
|
||||
|
||||
Reference in New Issue
Block a user