[Model] MLPSpeculator speculative decoding support (#4947)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>

Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Davis Wertheimer <Davis.Wertheimer@ibm.com>
This commit is contained in:
Joshua Rosenkranz
2024-06-20 20:23:12 -04:00
committed by GitHub
parent 6c5b7af152
commit b12518d3cf
18 changed files with 523 additions and 40 deletions

View File

@@ -5,6 +5,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
# `FalconConfig` class from the official HuggingFace transformers library.
from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
from vllm.transformers_utils.configs.mpt import MPTConfig
__all__ = [
@@ -13,4 +14,5 @@ __all__ = [
"MPTConfig",
"RWConfig",
"JAISConfig",
"MLPSpeculatorConfig",
]

View File

@@ -0,0 +1,50 @@
from typing import List, Optional
from transformers import PretrainedConfig
class MLPSpeculatorConfig(PretrainedConfig):
model_type = "mlp_speculator"
attribute_map = {
"hidden_size": "emb_dim",
}
def __init__(self,
vocab_size: int = 32000,
emb_dim: int = 4096,
inner_dim: int = 0,
n_predict: int = 3,
top_k_tokens_per_head: Optional[List[int]] = None,
n_candidates: int = 5,
**kwargs):
"""
Initialize an MLPSpeculatorConfig
Args:
vocab_size: int
the model vocab size
emb_dim: int
the model embedding dimension
inner_dim: int
the inner dimension of the model. If 0, will be the emb_dim.
n_predict: int
the number of lookaheads for the speculator
top_k_tokens_per_head: List[int]
Number of tokens to consider from each head when forming the
candidate tree.
For each candidate branch in the tree, head n produces topk[n]
additional sub-branches.
n_candidates: int
number of child candidates to create per sequence
"""
if top_k_tokens_per_head is None:
top_k_tokens_per_head = [5, 4, 3]
assert len(top_k_tokens_per_head) == n_predict
self.vocab_size = vocab_size
self.emb_dim = emb_dim
self.inner_dim = inner_dim
self.n_predict = n_predict
self.top_k_tokens_per_head = top_k_tokens_per_head
self.n_candidates = n_candidates
super().__init__(**kwargs)