[BugFix] Fix MLPSpeculator handling of num_speculative_tokens (#5876)
This commit is contained in:
@@ -11,6 +11,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs import MLPSpeculatorConfig
|
||||
|
||||
|
||||
class MLPSpeculatorLayerNorm(nn.Module):
|
||||
@@ -48,7 +49,7 @@ class MLPSpeculatorLayerNorm(nn.Module):
|
||||
|
||||
class MLPSpeculator(nn.Module):
|
||||
|
||||
def __init__(self, config, **kwargs) -> None:
|
||||
def __init__(self, config: MLPSpeculatorConfig, **kwargs) -> None:
|
||||
super().__init__()
|
||||
self.n_predict = config.n_predict
|
||||
self.vocab_size = config.vocab_size
|
||||
@@ -56,8 +57,7 @@ class MLPSpeculator(nn.Module):
|
||||
self.inner_dim = config.inner_dim if config.inner_dim != 0 \
|
||||
else config.emb_dim
|
||||
|
||||
self.max_speculative_tokens = getattr(config, "max_speculative_tokens",
|
||||
self.n_predict)
|
||||
self.max_speculative_tokens = config.num_lookahead_tokens
|
||||
|
||||
self.emb = nn.ModuleList([
|
||||
VocabParallelEmbedding(config.vocab_size,
|
||||
@@ -137,7 +137,8 @@ class MLPSpeculator(nn.Module):
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
param = params_dict[name.replace("speculator.", "")]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
param = params_dict.get(name.replace("speculator.", ""))
|
||||
if param is not None:
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
Reference in New Issue
Block a user