[V1] Logit processors for rejection sampler (#19482)

Signed-off-by: southfreebird <yvorott@gmail.com>
Signed-off-by: Sergei Skvortsov <sergeyskv@nebius.com>
Signed-off-by: Sergei Skvortsov <yvorott@gmail.com>
Co-authored-by: Sergei Skvortsov <sergeyskv@nebius.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Sergei Skvortsov
2025-10-07 21:02:49 +01:00
committed by GitHub
parent 0c824fc46f
commit 6ebaf43ee4
12 changed files with 471 additions and 92 deletions

View File

@@ -1,12 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import numpy as np
import pytest
import torch
from tests.v1.sample.utils import create_allowed_token_ids
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.sample.logits_processor import LogitsProcessors
@@ -51,26 +50,6 @@ def _create_prompt_tokens_tensor(
)
def _create_allowed_token_ids(
batch_size: int,
vocab_size: int,
num_allowed_token_ids: int,
device: torch.device,
) -> Optional[torch.Tensor]:
mask: Optional[torch.Tensor] = None
for i in range(batch_size):
if i % 2 == 1:
continue
if mask is None:
mask = torch.zeros(
(batch_size, vocab_size), dtype=torch.bool, device=device
)
start = min(i, vocab_size - 1)
end = min(i + num_allowed_token_ids, vocab_size - 1)
mask[i, start:end] = True
return mask
def _create_bad_words_token_ids(
batch_size: int,
vocab_size: int,
@@ -173,6 +152,7 @@ def _create_default_sampling_metadata(
prompt_token_ids, vocab_size, device
),
output_token_ids=output_token_ids,
spec_token_ids=[[] for _ in range(batch_size)],
frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
@@ -241,7 +221,9 @@ def test_sampler_presence_penalty(
)
sampling_metadata.no_penalties = False
sampler = Sampler()
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
logits = sampler.apply_penalties(
fake_logits, sampling_metadata, sampling_metadata.output_token_ids
)
logits = logits.cpu()
for batch_idx in range(batch_size):
# Since all tokens initially have the same logits, the non-penalized
@@ -293,7 +275,9 @@ def test_sampler_frequency_penalty(
sampling_metadata.output_token_ids = output_token_ids
sampling_metadata.no_penalties = False
sampler = Sampler()
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
logits = sampler.apply_penalties(
fake_logits, sampling_metadata, sampling_metadata.output_token_ids
)
logits = logits.cpu()
for batch_idx in range(batch_size):
non_penalized_token_id = logits[batch_idx].argmax().item()
@@ -343,7 +327,9 @@ def test_sampler_repetition_penalty(
)
sampling_metadata.no_penalties = False
sampler = Sampler()
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
logits = sampler.apply_penalties(
fake_logits, sampling_metadata, sampling_metadata.output_token_ids
)
logits = logits.cpu()
for batch_idx in range(batch_size):
non_penalized_token_id = logits[batch_idx].argmax().item()
@@ -394,7 +380,7 @@ def test_sampler_allowed_token_ids(
sampling_metadata = _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)
)
mask = _create_allowed_token_ids(
mask = create_allowed_token_ids(
batch_size=batch_size,
vocab_size=VOCAB_SIZE,
num_allowed_token_ids=num_allowed_token_ids,
@@ -402,7 +388,9 @@ def test_sampler_allowed_token_ids(
)
sampling_metadata.allowed_token_ids_mask = mask
sampler = Sampler()
logits = sampler.apply_allowed_token_ids(fake_logits, sampling_metadata)
logits = sampler.apply_logits_processors(
fake_logits, sampling_metadata, predict_bonus_token=False
)
logits = logits.cpu()
for batch_idx in range(batch_size):
logits_for_req = logits[batch_idx]
@@ -444,7 +432,9 @@ def test_sampler_bad_words(
sampling_metadata, VOCAB_SIZE
)
sampler = Sampler()
logits = sampler.apply_bad_words(fake_logits, sampling_metadata)
logits = sampler.apply_logits_processors(
fake_logits, sampling_metadata, predict_bonus_token=False
)
logits = logits.cpu()
for batch_idx in range(batch_size):
logits_for_req = logits[batch_idx]