Add Falcon support (new) (#592)
This commit is contained in:
@@ -94,8 +94,13 @@ class ModelConfig:
|
||||
return self.hf_config.hidden_size // self.hf_config.num_attention_heads
|
||||
|
||||
def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
|
||||
# For GPTBigCode:
|
||||
if getattr(self.hf_config, "multi_query", False):
|
||||
# For GPTBigCode & Falcon:
|
||||
# Note: for falcon, when new_decoder_architecture is True, the
|
||||
# multi_query flag is ignored and we use n_head_kv for the number of
|
||||
# KV heads.
|
||||
if (getattr(self.hf_config, "multi_query", False) and
|
||||
(self.hf_config.model_type == "falcon" and
|
||||
not getattr(self.hf_config, "new_decoder_architecture", False))):
|
||||
# Multi-query attention, only one KV head.
|
||||
return 1
|
||||
# For Falcon:
|
||||
|
||||
Reference in New Issue
Block a user