@@ -39,7 +39,7 @@ class Sampler(nn.Module):
|
||||
# Compute the probabilities.
|
||||
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
||||
# Compute the log probabilities (before applying top-p).
|
||||
logprobs = torch.log(probs)
|
||||
logprobs = torch.log(probs, out=logits)
|
||||
|
||||
# Apply top-p truncation.
|
||||
top_ps = _get_top_ps(input_metadata)
|
||||
|
||||
Reference in New Issue
Block a user