[Speculative Decoding] EAGLE Implementation with Top-1 proposer (#6830)

This commit is contained in:
Abhinav Goyal
2024-08-22 15:12:24 +05:30
committed by GitHub
parent b3856bef7d
commit a3fce56b88
17 changed files with 854 additions and 83 deletions

View File

@@ -30,6 +30,19 @@ class ResidualBlock(nn.Module):
class Medusa(nn.Module):
"""This class implements the Medusa draft model from the paper: https://arxiv.org/abs/2401.10774
Reference implementation: https://github.com/FasterDecoding/Medusa
Differences from reference implementation:
1. Currently this only supports generating proposals from top-1 tokens.
2. We have an optional token_map which reduces draft vocab to most
frequently used tokens to give some additional speed-up by reducing
sampling overhead. This is disabled unless the checkpoint file has
explicit token_map tensor and config has an optional attribute
truncated_vocab_size < vocab_size. To use this technique, one has to find
the top-k most frequent tokens in target dataset and add that as a tensor
in the draft checkpoint (using key token_map). Also, the draft config
needs to have truncated_vocab_size (=k) as an attribute."""
def __init__(self, config: MedusaConfig, **_) -> None:
super().__init__()
@@ -57,6 +70,12 @@ class Medusa(nn.Module):
self.truncated_vocab_size,
logit_scale)
# Token map is a idx to token mapping to reduce the vocab size for
# the draft model. Using smaller vocab size for draft, containing
# only most frequent tokens reduces the speculation overhead. This
# doesn't affect the acceptance rate much and thus gives more speed
# -up. By default, this is disabled and is only used if the EAGLE
# checkpoint file has token_map tensor.
self.token_map = None
def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]: