Add support for LLaMA-2 (#505)

This commit is contained in:
Zhuohan Li
2023-07-20 11:38:27 -07:00
committed by GitHub
parent c487a221ee
commit 6fc2a38b11
7 changed files with 67 additions and 38 deletions

View File

@@ -100,7 +100,12 @@ class ModelConfig:
             return 1
         # For Falcon:
         if getattr(self.hf_config, "n_head_kv", None) is not None:
-            return self.hf_config.n_head_kv
+            return (self.hf_config.n_head_kv //
+                    parallel_config.tensor_parallel_size)
+        # For LLaMA-2:
+        if getattr(self.hf_config, "num_key_value_heads", None) is not None:
+            return (self.hf_config.num_key_value_heads //
+                    parallel_config.tensor_parallel_size)
         total_num_attention_heads = self.hf_config.num_attention_heads
         return total_num_attention_heads // parallel_config.tensor_parallel_size