Add Falcon support (new) (#592)

This commit is contained in:
Zhuohan Li
2023-08-02 14:04:39 -07:00
committed by GitHub
parent 20044cab7a
commit 1b0bd0fe8a
16 changed files with 680 additions and 122 deletions

View File

@@ -94,8 +94,13 @@ class ModelConfig:
return self.hf_config.hidden_size // self.hf_config.num_attention_heads
def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
# For GPTBigCode:
if getattr(self.hf_config, "multi_query", False):
# For GPTBigCode & Falcon:
# Note: for falcon, when new_decoder_architecture is True, the
# multi_query flag is ignored and we use n_head_kv for the number of
# KV heads.
if (getattr(self.hf_config, "multi_query", False) and
(self.hf_config.model_type == "falcon" and
not getattr(self.hf_config, "new_decoder_architecture", False))):
# Multi-query attention, only one KV head.
return 1
# For Falcon: