[BugFix] Fix MLPSpeculator handling of num_speculative_tokens (#5876)

Nick Hill
2024-06-27 10:59:33 -07:00
committed by GitHub
parent 3fd02bda51
commit 691e29ecf3
3 changed files with 18 additions and 10 deletions

vllm/model_executor/models/mlp_speculator.py

@@ -11,6 +11,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import SamplerOutput
+from vllm.transformers_utils.configs import MLPSpeculatorConfig
 
 
 class MLPSpeculatorLayerNorm(nn.Module):
@@ -48,7 +49,7 @@ class MLPSpeculatorLayerNorm(nn.Module):
 
 class MLPSpeculator(nn.Module):
 
-    def __init__(self, config, **kwargs) -> None:
+    def __init__(self, config: MLPSpeculatorConfig, **kwargs) -> None:
         super().__init__()
         self.n_predict = config.n_predict
         self.vocab_size = config.vocab_size
@@ -56,8 +57,7 @@ class MLPSpeculator(nn.Module):
         self.inner_dim = config.inner_dim if config.inner_dim != 0 \
             else config.emb_dim
 
-        self.max_speculative_tokens = getattr(config, "max_speculative_tokens",
-                                              self.n_predict)
+        self.max_speculative_tokens = config.num_lookahead_tokens
 
         self.emb = nn.ModuleList([
             VocabParallelEmbedding(config.vocab_size,
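
Note on the hunk above: the old getattr fallback sized the speculator from the checkpoint's n_predict whenever the config had no max_speculative_tokens attribute, so a smaller num_speculative_tokens requested at the engine level was silently ignored. A minimal sketch of the behavioral difference (the _SpecConfig class below is hypothetical, not vLLM code):

# _SpecConfig stands in for MLPSpeculatorConfig; in vLLM,
# num_lookahead_tokens is derived from the engine-level
# num_speculative_tokens setting.
class _SpecConfig:
    def __init__(self, n_predict: int, num_lookahead_tokens: int):
        self.n_predict = n_predict
        self.num_lookahead_tokens = num_lookahead_tokens

config = _SpecConfig(n_predict=4, num_lookahead_tokens=3)

# Old behavior: no max_speculative_tokens attribute exists, so the
# fallback sizes the speculator for n_predict (4) tokens.
old = getattr(config, "max_speculative_tokens", config.n_predict)

# New behavior: honor the requested lookahead exactly.
new = config.num_lookahead_tokens

assert (old, new) == (4, 3)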
@@ -137,7 +137,8 @@ class MLPSpeculator(nn.Module):
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
-            param = params_dict[name.replace("speculator.", "")]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
+            param = params_dict.get(name.replace("speculator.", ""))
+            if param is not None:
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
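
The load_weights change above swaps dict indexing for .get, so checkpoint tensors whose names have no matching model parameter are skipped instead of raising KeyError. A standalone sketch of the same tolerant-loading pattern (hypothetical module and checkpoint, not vLLM code):

import torch
import torch.nn as nn

model = nn.Linear(4, 4, bias=False)
params_dict = dict(model.named_parameters())  # {"weight": ...}

checkpoint = {
    "speculator.weight": torch.zeros(4, 4),      # matches model.weight
    "speculator.unused_stat": torch.zeros(1),    # no matching parameter
}

for name, loaded_weight in checkpoint.items():
    param = params_dict.get(name.replace("speculator.", ""))
    if param is not None:  # skip keys the model does not define
        param.data.copy_(loaded_weight)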