[V1] Logits processor docs (#22919)
Signed-off-by: Andrew Feldman <afeldman@redhat.com>
Signed-off-by: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com>
Co-authored-by: Joseph Marinier <Joseph.Marinier@gmail.com>
This commit is contained in:
@@ -56,7 +56,6 @@ class DummyLogitsProcessor(LogitsProcessor):
|
||||
self.req_info: dict[int, int] = {}
|
||||
|
||||
def is_argmax_invariant(self) -> bool:
|
||||
"""Never impacts greedy sampling"""
|
||||
return False
|
||||
|
||||
def update_state(self, batch_update: Optional[BatchUpdate]):
|
||||
@@ -75,13 +74,12 @@ class DummyLogitsProcessor(LogitsProcessor):
|
||||
return logits
|
||||
|
||||
# Save target values before modification
|
||||
rows_list = list(self.req_info.keys())
|
||||
cols = torch.tensor(
|
||||
[self.req_info[i] for i in rows_list],
|
||||
dtype=torch.long,
|
||||
device=logits.device,
|
||||
list(self.req_info.values()), dtype=torch.long, device=logits.device
|
||||
)
|
||||
rows = torch.tensor(
|
||||
list(self.req_info.keys()), dtype=torch.long, device=logits.device
|
||||
)
|
||||
rows = torch.tensor(rows_list, dtype=torch.long, device=logits.device)
|
||||
values_to_keep = logits[rows, cols].clone()
|
||||
|
||||
# Mask all but target tokens
|
||||
|
||||
Reference in New Issue
Block a user