[spec decode] Fix MTP inference path for MiMo-7B model (#25136)
Signed-off-by: zixi-qi <qizixi@meta.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -241,6 +241,15 @@ class MiMoMTP(nn.Module):
|
||||
|
||||
def map_model_name_to_mtp_param_name(self, name: str) -> str:
|
||||
import regex as re
|
||||
|
||||
# append mtp_start_layer_idx
|
||||
pattern = r"(model\.mtp_layers\.)(\d+)(\.)"
|
||||
match = re.match(pattern, name)
|
||||
if match:
|
||||
original_num = int(match.group(2))
|
||||
new_num = original_num + self.config.num_hidden_layers
|
||||
name = name.replace(match.group(), f"{match.group(1)}{new_num}.")
|
||||
# check for early turn
|
||||
name_without_prefix = [
|
||||
"token_layernorm", "hidden_layernorm", "input_proj",
|
||||
"final_layernorm"
|
||||
@@ -248,10 +257,11 @@ class MiMoMTP(nn.Module):
|
||||
for sub_name in name_without_prefix:
|
||||
if sub_name in name:
|
||||
return name
|
||||
pattern = r"model.mtp_layers.(\d+)."
|
||||
group = re.match(pattern, name)
|
||||
if group is not None:
|
||||
name = name.replace(group.group(), group.group() + "mtp_block.")
|
||||
# add mtp_block
|
||||
pattern = r"(model\.mtp_layers\.\d+\.)"
|
||||
match = re.match(pattern, name)
|
||||
if match:
|
||||
name = name.replace(match.group(), match.group() + "mtp_block.")
|
||||
return name
|
||||
|
||||
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
|
||||
|
||||
Reference in New Issue
Block a user