ChatGLM Support (#1261)
This commit is contained in:
@@ -166,6 +166,10 @@ class ModelConfig:
        if getattr(self.hf_config, "num_key_value_heads", None) is not None:
            return (self.hf_config.num_key_value_heads //
                    parallel_config.tensor_parallel_size)
        # For ChatGLM-2:
        if getattr(self.hf_config, "multi_query_group_num", None) is not None:
            return (self.hf_config.multi_query_group_num //
                    parallel_config.tensor_parallel_size)
        total_num_attention_heads = self.hf_config.num_attention_heads
        return total_num_attention_heads // parallel_config.tensor_parallel_size
Reference in New Issue
Block a user