[Speculative Decoding] EAGLE Implementation with Top-1 proposer (#6830)
This commit is contained in:
@@ -30,6 +30,19 @@ class ResidualBlock(nn.Module):
|
||||
|
||||
|
||||
class Medusa(nn.Module):
|
||||
"""This class implements the Medusa draft model from the paper: https://arxiv.org/abs/2401.10774
|
||||
Reference implementation: https://github.com/FasterDecoding/Medusa
|
||||
|
||||
Differences from reference implementation:
|
||||
1. Currently this only supports generating proposals from top-1 tokens.
|
||||
2. We have an optional token_map which reduces draft vocab to most
|
||||
frequently used tokens to give some additional speed-up by reducing
|
||||
sampling overhead. This is disabled unless the checkpoint file has
|
||||
explicit token_map tensor and config has an optional attribute
|
||||
truncated_vocab_size < vocab_size. To use this technique, one has to find
|
||||
the top-k most frequent tokens in target dataset and add that as a tensor
|
||||
in the draft checkpoint (using key token_map). Also, the draft config
|
||||
needs to have truncated_vocab_size (=k) as an attribute."""
|
||||
|
||||
def __init__(self, config: MedusaConfig, **_) -> None:
|
||||
super().__init__()
|
||||
@@ -57,6 +70,12 @@ class Medusa(nn.Module):
|
||||
self.truncated_vocab_size,
|
||||
logit_scale)
|
||||
|
||||
# Token map is a idx to token mapping to reduce the vocab size for
|
||||
# the draft model. Using smaller vocab size for draft, containing
|
||||
# only most frequent tokens reduces the speculation overhead. This
|
||||
# doesn't affect the acceptance rate much and thus gives more speed
|
||||
# -up. By default, this is disabled and is only used if the EAGLE
|
||||
# checkpoint file has token_map tensor.
|
||||
self.token_map = None
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:
|
||||
|
||||
Reference in New Issue
Block a user