[Bugfix] Weight loading fix for OPT model (#9042)

Co-authored-by: dvres <dvres@fri.uni-lj.si>
This commit is contained in:
Domen Vreš
2024-10-04 01:53:29 +02:00
committed by GitHub
parent 91add85ec4
commit 2838d6b38e

View File

@@ -353,7 +353,7 @@ class OPTForCausalLM(nn.Module):
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
-if "lm_head.weight" in name:
+if "lm_head.weight" in name and self.config.tie_word_embeddings:
continue
if name.startswith("decoder."):
name = "model." + name