Implement LLaMA (#9)

Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
Woosuk Kwon
2023-03-29 21:25:32 -07:00
committed by GitHub
parent a1b3de86cd
commit 80a2f812f1
7 changed files with 500 additions and 35 deletions

View File

@@ -39,7 +39,7 @@ class Sampler(nn.Module):
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities (before applying top-p).
- logprobs = torch.log(probs)
+ logprobs = torch.log(probs, out=logits)
# Apply top-p truncation.
top_ps = _get_top_ps(input_metadata)