[MODEL] Add support for Zamba2 models (#13185)
Signed-off-by: Yury Tokpanov <yury@zyphra.com> Signed-off-by: Quentin Anthony <qganthony@yahoo.com> Co-authored-by: Quentin Anthony <qganthony@yahoo.com> Co-authored-by: Tyler Michael Smith <tysmith@redhat.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
@@ -821,6 +821,11 @@ class ModelConfig:
|
||||
if qk_rope_head_dim and qk_nope_head_dim:
|
||||
return qk_rope_head_dim + qk_nope_head_dim
|
||||
|
||||
if hasattr(self.hf_text_config,
|
||||
"model_type") and (self.hf_text_config.model_type
|
||||
== "zamba2"):
|
||||
return self.hf_text_config.attention_head_dim
|
||||
|
||||
if self.is_attention_free:
|
||||
return 0
|
||||
|
||||
@@ -944,6 +949,15 @@ class ModelConfig:
|
||||
"cannot determine the num of "
|
||||
f"{block_type.value} layers")
|
||||
|
||||
if hasattr(self.hf_text_config,
|
||||
"model_type") and (self.hf_text_config.model_type
|
||||
== "zamba2"):
|
||||
if attn_block_type:
|
||||
return sum(t == "hybrid"
|
||||
for t in layers_block_type_value[start:end])
|
||||
else:
|
||||
return self.get_num_layers(parallel_config)
|
||||
|
||||
return sum(t == block_type.value
|
||||
for t in layers_block_type_value[start:end])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user