[Bugfix] Fix lm_head weights tying with lora for llama (#9227)
@@ -524,7 +524,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                 quant_config=quant_config,
             )
             if config.tie_word_embeddings:
-                self.lm_head = self.model.embed_tokens
+                self.lm_head = self.lm_head.tie_weights(
+                    self.model.embed_tokens)
 
             logit_scale = getattr(config, "logit_scale", 1.0)
             self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
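Why the change: with LoRA enabled, lm_head is its own module (which LoRA machinery may wrap), so replacing it outright with `self.model.embed_tokens` discards that module and breaks weight tying for the LoRA path. Calling `tie_weights` keeps lm_head as a separate module and only shares the embedding weight tensor. The snippet below is a minimal sketch of that idea using a hypothetical `TiedLMHead` class; it is not vLLM's actual `ParallelLMHead` implementation.

# Minimal sketch (illustrative, not vLLM's implementation) of sharing the
# embedding weight with the output head instead of replacing the module.
import torch
import torch.nn as nn


class TiedLMHead(nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(vocab_size, hidden_size))

    def tie_weights(self, embed_tokens: nn.Embedding) -> "TiedLMHead":
        # Share the underlying tensor rather than swapping out the module,
        # so anything wrapping lm_head (e.g. LoRA adapters) stays intact.
        self.weight = embed_tokens.weight
        return self

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states @ self.weight.t()


if __name__ == "__main__":
    embed = nn.Embedding(32000, 4096)
    lm_head = TiedLMHead(32000, 4096).tie_weights(embed)
    assert lm_head.weight is embed.weight  # shared, not copied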