[Bugfix] support tie_word_embeddings for all models (#5724)
This commit is contained in:
@@ -321,6 +321,9 @@ class CohereForCausalLM(nn.Module):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
# currently all existing command R models have `tie_word_embeddings`
|
||||
# enabled
|
||||
assert config.tie_word_embeddings
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
|
||||
Reference in New Issue
Block a user