Add docstrings to some modules and classes (#100)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
"""A layer that samples the next tokens from the model's outputs."""
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
@@ -12,6 +13,19 @@ from cacheflow.sequence import SequenceOutputs
|
||||
|
||||
|
||||
class Sampler(nn.Module):
|
||||
"""Samples the next tokens from the model's outputs.
|
||||
|
||||
This layer does the following:
|
||||
1. Discard the hidden states that are not used for sampling (i.e., all
|
||||
tokens except the final one in each prompt).
|
||||
2. Compute the logits for the next tokens.
|
||||
3. Apply presence and frequency penalties.
|
||||
4. Apply temperature scaling.
|
||||
5. Apply top-p and top-k truncation.
|
||||
6. Sample the next tokens.
|
||||
Here, each sequence group within the batch can have different sampling
|
||||
parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
|
||||
"""
|
||||
|
||||
def __init__(self, vocab_size: int) -> None:
|
||||
super().__init__()
|
||||
|
||||
Reference in New Issue
Block a user