[Bugfix] support tie_word_embeddings for all models (#5724)

This commit is contained in:
Zijian Hu
2024-08-19 20:00:04 -07:00
committed by GitHub
parent 0df7ec0b2d
commit f4fc7337bf
30 changed files with 90 additions and 16 deletions

View File

@@ -321,6 +321,9 @@ class CohereForCausalLM(nn.Module):
) -> None:
super().__init__()
self.config = config
# currently all existing command R models have `tie_word_embeddings`
# enabled
assert config.tie_word_embeddings
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size