[Model][MiniMaxText01] Support MiniMaxText01 model inference (#13454)
Signed-off-by: qscqesze <475517977@qq.com>
Co-authored-by: qingjun <qingjun@minimaxi.com>
Co-authored-by: qscqesze <475517977@qq.com>
This commit is contained in:
@@ -971,26 +971,34 @@ class ModelConfig:
|
||||
return sum(not bc.attention.no_op
|
||||
for bc in block_configs[start:end])
|
||||
else:
|
||||
# Hybrid model
|
||||
# Hybrid model Jamba
|
||||
layers_block_type_value = getattr(self.hf_config,
|
||||
"layers_block_type", None)
|
||||
if layers_block_type_value is None:
|
||||
raise ValueError("The model is an hybrid without a "
|
||||
"layers_block_type in the hf_config, "
|
||||
"cannot determine the num of "
|
||||
f"{block_type.value} layers")
|
||||
if layers_block_type_value is not None:
|
||||
if hasattr(self.hf_text_config,
|
||||
"model_type") and (self.hf_text_config.model_type
|
||||
== "zamba2"):
|
||||
if attn_block_type:
|
||||
return sum(t == "hybrid"
|
||||
for t in layers_block_type_value[start:end])
|
||||
else:
|
||||
return self.get_num_layers(parallel_config)
|
||||
return sum(t == block_type.value
|
||||
for t in layers_block_type_value[start:end])
|
||||
|
||||
if hasattr(self.hf_text_config,
|
||||
"model_type") and (self.hf_text_config.model_type
|
||||
== "zamba2"):
|
||||
if attn_block_type:
|
||||
return sum(t == "hybrid"
|
||||
for t in layers_block_type_value[start:end])
|
||||
else:
|
||||
return self.get_num_layers(parallel_config)
|
||||
# Hybrid model Minimax
|
||||
attn_type_list = getattr(self.hf_config, "attn_type_list", None)
|
||||
if attn_type_list:
|
||||
return sum(t == 1 for t in attn_type_list[start:end])
|
||||
|
||||
return sum(t == block_type.value
|
||||
for t in layers_block_type_value[start:end])
|
||||
if layers_block_type_value is None and attn_type_list is None:
|
||||
raise ValueError(
|
||||
"The model is an hybrid without a"
|
||||
"layers_block_type or an attn_type_list in the hf_config,"
|
||||
"cannot determine the num of "
|
||||
f"{block_type.value} layers")
|
||||
|
||||
return sum(t == 1 for t in attn_type_list[start:end])
|
||||
|
||||
def get_multimodal_config(self) -> "MultiModalConfig":
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user