[V1] Logit processors for rejection sampler (#19482)
Signed-off-by: southfreebird <yvorott@gmail.com> Signed-off-by: Sergei Skvortsov <sergeyskv@nebius.com> Signed-off-by: Sergei Skvortsov <yvorott@gmail.com> Co-authored-by: Sergei Skvortsov <sergeyskv@nebius.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -215,3 +215,23 @@ def fake_apply_logitsprocs(
|
||||
for processor in test_fakes.get_logitsprocs():
|
||||
logits = processor.apply(logits)
|
||||
return logits
|
||||
|
||||
|
||||
def create_allowed_token_ids(
|
||||
batch_size: int,
|
||||
vocab_size: int,
|
||||
num_allowed_token_ids: int,
|
||||
device: torch.device,
|
||||
) -> Optional[torch.Tensor]:
|
||||
mask: Optional[torch.Tensor] = None
|
||||
for i in range(batch_size):
|
||||
if i % 2 == 1:
|
||||
continue
|
||||
if mask is None:
|
||||
mask = torch.zeros(
|
||||
(batch_size, vocab_size), dtype=torch.bool, device=device
|
||||
)
|
||||
start = min(i, vocab_size - 1)
|
||||
end = min(i + num_allowed_token_ids, vocab_size - 1)
|
||||
mask[i, start:end] = True
|
||||
return mask
|
||||
|
||||
Reference in New Issue
Block a user