[Bugfix] Fix spec decoding when seed is none in a batch (#10863)
Signed-off-by: Wallas Santos <wallashss@ibm.com>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from functools import cached_property
|
||||
from importlib.util import find_spec
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.jit
|
||||
@@ -386,16 +386,12 @@ def _multinomial(
|
||||
if not seeded_seqs:
|
||||
q.exponential_(1.0)
|
||||
else:
|
||||
non_seeded_indices: List[int] = []
|
||||
start = 0
|
||||
for idx in range(len(q) // k):
|
||||
end = start + k
|
||||
generator = seeded_seqs.get(idx)
|
||||
if generator is None:
|
||||
non_seeded_indices.extend(list(range(start, end)))
|
||||
else:
|
||||
q[start:end].exponential_(1.0, generator=generator)
|
||||
# Note: generator might be None for non seeded
|
||||
q[start:end].exponential_(1.0, generator=generator)
|
||||
start = end
|
||||
q[non_seeded_indices].exponential_(1.0)
|
||||
|
||||
return probs.div_(q).argmax(dim=1).view(-1, num_samples)
|
||||
|
||||
Reference in New Issue
Block a user