[Bugfix] Fix lm_head weights tying with lora for llama (#9227)

Isotr0py
2024-10-10 21:11:56 +08:00
committed by GitHub
parent f3a507f1d3
commit 07c11cf4d4
2 changed files with 12 additions and 2 deletions

vllm/model_executor/models/llama.py

@@ -524,7 +524,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             quant_config=quant_config,
         )
         if config.tie_word_embeddings:
-            self.lm_head = self.model.embed_tokens
+            self.lm_head = self.lm_head.tie_weights(
+                self.model.embed_tokens)
         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
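
The diff above only calls tie_weights; its definition lives in the second changed file, which is not shown here. Below is a minimal sketch of what such a helper could look like, using a plain PyTorch stand-in for vLLM's ParallelLMHead. The class body, constructor signature, and dimensions are illustrative assumptions, not the actual vLLM implementation.

import torch
import torch.nn as nn

class ParallelLMHead(nn.Module):
    """Illustrative stand-in for vLLM's ParallelLMHead; the real class is
    more involved (vocab padding, tensor parallelism, quantization)."""

    def __init__(self, num_embeddings: int, embedding_dim: int) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.empty(num_embeddings, embedding_dim))

    def tie_weights(self, embed_tokens: nn.Embedding) -> "ParallelLMHead":
        # Share the embedding matrix instead of replacing the whole module:
        # lm_head remains a distinct module object, so machinery that targets
        # lm_head separately from embed_tokens (e.g. LoRA) keeps working,
        # while the underlying Parameter is still shared with the embeddings.
        self.weight = embed_tokens.weight
        return self

# Usage mirroring the patched line in LlamaForCausalLM.__init__:
embed_tokens = nn.Embedding(128, 16)
lm_head = ParallelLMHead(128, 16).tie_weights(embed_tokens)
assert lm_head.weight is embed_tokens.weight  # storage is shared

Returning self is what lets the patch keep the original assignment shape (self.lm_head = self.lm_head.tie_weights(...)). Before the fix, self.lm_head = self.model.embed_tokens made the two attributes the very same module object, which is presumably what broke LoRA's handling of the tied head for Llama.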