Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
This commit is contained in:
Cyrus Leung
2025-11-15 14:47:41 +08:00
committed by GitHub
parent 6965ef436f
commit 98b4d389ed
15 changed files with 122 additions and 91 deletions

View File

@@ -3,6 +3,7 @@
from dataclasses import replace
import numpy as np
import torch
import torch.nn as nn
@@ -204,7 +205,7 @@ class RejectionSampler(nn.Module):
def parse_output(
output_token_ids: torch.Tensor,
vocab_size: int,
) -> list[list[int]]:
) -> list[np.ndarray]:
"""Parse the output of the rejection sampler.
Args:
output_token_ids: The sampled token IDs in shape
@@ -220,10 +221,7 @@ class RejectionSampler(nn.Module):
valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
output_token_ids_np < vocab_size
)
outputs = [
row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
]
return outputs
return [row[valid_mask[i]] for i, row in enumerate(output_token_ids_np)]
def apply_logits_processors(
self,