Fix hc_head mapping: checkpoint uses hc_head.hc_fn, model params are flat hc_head_fn

- Removed the "hc_head" → "model.hc_head" prefix mapping (checkpoint keys already start with model.hc_head.)
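- Fixed the substring mappings: hc_head.hc_fn → hc_head_fn (previously hc_head.fn → hc_head_fn), and likewise hc_head.hc_base → hc_head_base and hc_head.hc_scale → hc_head_scale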
- The model stores these as flat parameters (e.g. self.hc_head_fn), not inside an hc_head sub-module
This commit is contained in:
2026-05-19 03:58:25 +00:00
parent 909a2710e4
commit 6f9a400ae0
2 changed files with 12 additions and 12 deletions

View File

@@ -48,7 +48,8 @@ def make_nvfp4_mapper() -> WeightsMapper:
"layers.": "model.layers.",
"embed.": "model.embed.",
"norm.": "model.norm.",
"hc_head": "model.hc_head",
# hc_head prefix is NOT mapped: checkpoint keys already start with model.hc_head.
# The model keeps these as flat params (e.g. hc_head_fn), not under an hc_head sub-module.
"mtp.": "model.mtp.",
},
orig_to_new_regex=expert_rename_regex,
@@ -93,9 +94,9 @@ def make_nvfp4_mapper() -> WeightsMapper:
".ffn_hc.fn": ".hc_ffn_fn",
".ffn_hc.base": ".hc_ffn_base",
".ffn_hc.scale": ".hc_ffn_scale",
"hc_head.fn": "hc_head_fn",
"hc_head.base": "hc_head_base",
"hc_head.scale": "hc_head_scale",
"hc_head.hc_fn": "hc_head_fn",
"hc_head.hc_base": "hc_head_base",
"hc_head.hc_scale": "hc_head_scale",
},
)
@@ -156,10 +157,8 @@ def test_mapper():
("layers.0.attn_hc.fn", "model.layers.0.hc_attn_fn"),
("layers.0.ffn_hc.scale", "model.layers.0.hc_ffn_scale"),
# Global params
("embed.weight", "model.embed_tokens.weight"),
("norm.weight", "model.norm.weight"),
("hc_head.fn", "model.hc_head_fn"),
# HC head (checkpoint has model.hc_head.hc_fn, model params are flat hc_head_fn)
("hc_head.hc_fn", "hc_head_fn"),
# MTP (already uses ffn prefix in checkpoint)
("mtp.0.ffn.experts.0.w1.weight", "model.mtp.0.ffn.experts.0.w1.weight"),

View File

@@ -1640,7 +1640,8 @@ def _make_deepseek_v4_nvfp4_weights_mapper() -> WeightsMapper:
"layers.": "model.layers.",
"embed.": "model.embed.",
"norm.": "model.norm.",
"hc_head": "model.hc_head",
# hc_head prefix is NOT mapped here: checkpoint keys already start with model.hc_head.
# The model keeps these as flat params (e.g. hc_head_fn), not under an hc_head sub-module.
"mtp.": "model.mtp.",
},
orig_to_new_regex=expert_rename_regex,
@@ -1697,9 +1698,9 @@ def _make_deepseek_v4_nvfp4_weights_mapper() -> WeightsMapper:
".ffn_hc.fn": ".hc_ffn_fn",
".ffn_hc.base": ".hc_ffn_base",
".ffn_hc.scale": ".hc_ffn_scale",
"hc_head.fn": "hc_head_fn",
"hc_head.base": "hc_head_base",
"hc_head.scale": "hc_head_scale",
"hc_head.hc_fn": "hc_head_fn",
"hc_head.hc_base": "hc_head_base",
"hc_head.hc_scale": "hc_head_scale",
},
)