[Bugfix] support tie_word_embeddings for all models (#5724)

This commit is contained in:
Zijian Hu
2024-08-19 20:00:04 -07:00
committed by GitHub
parent 0df7ec0b2d
commit f4fc7337bf
30 changed files with 90 additions and 16 deletions

View File

@@ -414,6 +414,8 @@ class ArcticForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.num_experts = config.num_local_experts
self.num_experts_per_tok = config.num_experts_per_tok
self.unpadded_vocab_size = config.vocab_size