Optimize MQA Kernel (#452)

This commit is contained in:
Zhuohan Li
2023-07-14 20:06:40 -04:00
committed by GitHub
parent dbed69058c
commit 96853af5a8
5 changed files with 84 additions and 72 deletions

View File

@@ -94,6 +94,13 @@ class ModelConfig:
return self.hf_config.hidden_size // self.hf_config.num_attention_heads
def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
# For GPTBigCode:
if getattr(self.hf_config, "multi_query", False):
# Multi-query attention, only one KV head.
return 1
# For Falcon:
if getattr(self.hf_config, "n_head_kv", None) is not None:
return self.hf_config.n_head_kv
total_num_attention_heads = self.hf_config.num_attention_heads
return total_num_attention_heads // parallel_config.tensor_parallel_size