[Bugfix] Fix Mamba model initialization and MLP Speculator weights loading (#10456)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py
2024-11-20 13:04:05 +08:00
committed by GitHub
parent 9e05252b46
commit ad44437ba3
2 changed files with 4 additions and 7 deletions

View File

@@ -193,7 +193,8 @@ class MLPSpeculator(nn.Module):
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
param = params_dict.get(name.replace("speculator.", ""))
name = name.replace("speculator.", "")
param = params_dict.get(name)
if param is not None:
weight_loader = getattr(param, "weight_loader",
default_weight_loader)