[spec decode] Fix MTP inference path for MiMo-7B model (#25136)

Signed-off-by: zixi-qi <qizixi@meta.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
qizixi
2025-09-18 09:12:19 -07:00
committed by GitHub
parent 1c3b1634aa
commit c4cb0af98a
3 changed files with 20 additions and 6 deletions

View File

@@ -241,6 +241,15 @@ class MiMoMTP(nn.Module):
def map_model_name_to_mtp_param_name(self, name: str) -> str:
import regex as re
# append mtp_start_layer_idx
pattern = r"(model\.mtp_layers\.)(\d+)(\.)"
match = re.match(pattern, name)
if match:
original_num = int(match.group(2))
new_num = original_num + self.config.num_hidden_layers
name = name.replace(match.group(), f"{match.group(1)}{new_num}.")
# check for early turn
name_without_prefix = [
"token_layernorm", "hidden_layernorm", "input_proj",
"final_layernorm"
@@ -248,10 +257,11 @@ class MiMoMTP(nn.Module):
for sub_name in name_without_prefix:
if sub_name in name:
return name
pattern = r"model.mtp_layers.(\d+)."
group = re.match(pattern, name)
if group is not None:
name = name.replace(group.group(), group.group() + "mtp_block.")
# add mtp_block
pattern = r"(model\.mtp_layers\.\d+\.)"
match = re.match(pattern, name)
if match:
name = name.replace(match.group(), match.group() + "mtp_block.")
return name
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: