[V1] Logit processors for rejection sampler (#19482)

Signed-off-by: southfreebird <yvorott@gmail.com>
Signed-off-by: Sergei Skvortsov <sergeyskv@nebius.com>
Signed-off-by: Sergei Skvortsov <yvorott@gmail.com>
Co-authored-by: Sergei Skvortsov <sergeyskv@nebius.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Sergei Skvortsov
2025-10-07 21:02:49 +01:00
committed by GitHub
parent 0c824fc46f
commit 6ebaf43ee4
12 changed files with 471 additions and 92 deletions

View File

@@ -215,3 +215,23 @@ def fake_apply_logitsprocs(
for processor in test_fakes.get_logitsprocs():
logits = processor.apply(logits)
return logits
def create_allowed_token_ids(
batch_size: int,
vocab_size: int,
num_allowed_token_ids: int,
device: torch.device,
) -> Optional[torch.Tensor]:
mask: Optional[torch.Tensor] = None
for i in range(batch_size):
if i % 2 == 1:
continue
if mask is None:
mask = torch.zeros(
(batch_size, vocab_size), dtype=torch.bool, device=device
)
start = min(i, vocab_size - 1)
end = min(i + num_allowed_token_ids, vocab_size - 1)
mask[i, start:end] = True
return mask