[Bugfix] Fix spec decoding when seed is none in a batch (#10863)

Signed-off-by: Wallas Santos <wallashss@ibm.com>
This commit is contained in:
Wallas Henrique
2024-12-20 02:15:31 -03:00
committed by GitHub
parent b880ffb87e
commit 86c2d8fd1c
2 changed files with 66 additions and 7 deletions

View File

@@ -1,6 +1,6 @@
from functools import cached_property
from importlib.util import find_spec
from typing import Dict, List, Optional, Tuple
from typing import Dict, Optional, Tuple
import torch
import torch.jit
@@ -386,16 +386,12 @@ def _multinomial(
if not seeded_seqs:
q.exponential_(1.0)
else:
non_seeded_indices: List[int] = []
start = 0
for idx in range(len(q) // k):
end = start + k
generator = seeded_seqs.get(idx)
if generator is None:
non_seeded_indices.extend(list(range(start, end)))
else:
q[start:end].exponential_(1.0, generator=generator)
# Note: generator might be None for non seeded
q[start:end].exponential_(1.0, generator=generator)
start = end
q[non_seeded_indices].exponential_(1.0)
return probs.div_(q).argmax(dim=1).view(-1, num_samples)