Optimize MQA Kernel (#452)
@@ -94,6 +94,13 @@ class ModelConfig:
         return self.hf_config.hidden_size // self.hf_config.num_attention_heads
 
     def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
+        # For GPTBigCode:
+        if getattr(self.hf_config, "multi_query", False):
+            # Multi-query attention, only one KV head.
+            return 1
+        # For Falcon:
+        if getattr(self.hf_config, "n_head_kv", None) is not None:
+            return self.hf_config.n_head_kv
         total_num_attention_heads = self.hf_config.num_attention_heads
         return total_num_attention_heads // parallel_config.tensor_parallel_size
 
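For context, the branch order above determines how many key/value heads the attention code sees: 1 for GPTBigCode-style multi-query attention, the explicit n_head_kv for Falcon, and otherwise the per-rank share of full multi-head attention. Below is a minimal sketch of how the three branches resolve; the helper name get_num_kv_heads and the stand-in SimpleNamespace configs are illustrative, not part of this commit, and the head counts are example values only.

# Minimal sketch (not vLLM code): same branch order as get_num_heads above,
# using hypothetical stand-in configs instead of real Hugging Face configs.
from types import SimpleNamespace


def get_num_kv_heads(hf_config, tensor_parallel_size: int) -> int:
    if getattr(hf_config, "multi_query", False):
        # GPTBigCode-style MQA: one KV head shared by all query heads.
        return 1
    if getattr(hf_config, "n_head_kv", None) is not None:
        # Falcon exposes its KV head count directly on the config.
        return hf_config.n_head_kv
    # Plain multi-head attention: split heads across tensor-parallel ranks.
    return hf_config.num_attention_heads // tensor_parallel_size


# GPTBigCode-style config: multi_query=True always yields a single KV head.
assert get_num_kv_heads(SimpleNamespace(multi_query=True, num_attention_heads=48),
                        tensor_parallel_size=2) == 1
# Falcon-style config: an explicit n_head_kv wins over the head-count division.
assert get_num_kv_heads(SimpleNamespace(n_head_kv=8, num_attention_heads=128),
                        tensor_parallel_size=2) == 8
# Generic MHA config: 32 heads over 4 tensor-parallel ranks -> 8 per rank.
assert get_num_kv_heads(SimpleNamespace(num_attention_heads=32),
                        tensor_parallel_size=4) == 8

Fewer KV heads directly shrinks the KV cache and the per-token memory traffic in the attention kernel, which is what makes the MQA path worth special-casing here.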