Files
nvfp4-megamoe-kernel/tests/test_nvfp4_mapper.py
biondizzle 6f9a400ae0 Fix hc_head mapping: checkpoint uses hc_head.hc_fn, model params are flat hc_head_fn
- Removed hc_head prefix mapping (checkpoint already has model.hc_head.*)
- Fixed substr: hc_head.hc_fn→hc_head_fn (not hc_head.fn→hc_head_fn)
- The model has self.hc_head_fn as flat params, not inside a sub-module
2026-05-19 03:58:25 +00:00

187 lines
9.3 KiB
Python

"""Test that the NVFP4 weight mapper correctly maps checkpoint keys to model parameter names."""
import re
import sys
# Inline WeightsMapper (from vllm.model_executor.models.utils)
from dataclasses import dataclass, field
from typing import Any, Iterable, Mapping
@dataclass
class WeightsMapper:
orig_to_new_regex: Mapping[re.Pattern, str | None] = field(default_factory=dict)
orig_to_new_substr: Mapping[str, str | None] = field(default_factory=dict)
orig_to_new_prefix: Mapping[str, str | None] = field(default_factory=dict)
orig_to_new_suffix: Mapping[str, str | None] = field(default_factory=dict)
def _map_name(self, key: str) -> str | None:
for pattern, new_key in self.orig_to_new_regex.items():
if pattern.search(key):
if new_key is None:
return None
key = pattern.sub(new_key, key)
for substr, new_key in self.orig_to_new_substr.items():
if substr in key:
if new_key is None:
return None
key = key.replace(substr, new_key, 1)
for prefix, new_key in self.orig_to_new_prefix.items():
if key.startswith(prefix):
if new_key is None:
return None
key = key.replace(prefix, new_key, 1)
for suffix, new_key in self.orig_to_new_suffix.items():
if key.endswith(suffix):
if new_key is None:
return None
key = new_key.join(key.rsplit(suffix, 1))
return key
def make_nvfp4_mapper() -> WeightsMapper:
    """Build the WeightsMapper that renames NVFP4 checkpoint keys to model parameter names.

    Substring rules are order-sensitive: they are applied in insertion order
    and each rule sees the key as rewritten by the earlier ones, so the more
    specific rules are listed before the generic catch-alls.
    """
    # Per-expert projections: ".experts.<N>.<old>." -> ".experts.<N>.<new>."
    expert_projection_names = (
        ("gate_proj", "w1"),
        ("up_proj", "w3"),
        ("down_proj", "w2"),
    )
    regex_rules = {
        re.compile(rf"(\.experts\.\d+\.){old}\."): rf"\1{new}."
        for old, new in expert_projection_names
    }

    prefix_rules = {
        "layers.": "model.layers.",
        "embed.": "model.embed.",
        "norm.": "model.norm.",
        # hc_head NOT mapped — checkpoint already has model.hc_head.*
        # and model params are flat (hc_head_fn, not hc_head.fn)
        "mtp.": "model.mtp.",
    }

    suffix_rules = {
        "head.weight": "lm_head.weight",
        "embed.weight": "embed_tokens.weight",
        ".ffn_norm.weight": ".ffn.norm_gate.norm.weight",
        ".ffn.gate.weight": ".ffn.norm_gate.gate.weight",
        ".ffn.gate.bias": ".ffn.norm_gate.e_score_correction_bias",
        ".ffn.gate.tid2eid": ".ffn.norm_gate.tid2eid",
    }

    # Indexer keys go first: they also contain ".self_attn.compressor.",
    # which the generic compressor rule below would otherwise rewrite.
    indexer_rules = {
        ".self_attn.compressor.indexer.q_b_proj.": ".attn.indexer.wq_b.",
        ".self_attn.compressor.indexer.weights_proj.": ".attn.indexer.weights_proj.",
        ".self_attn.compressor.indexer.kv_norm.": ".attn.indexer.k_norm.",
        ".self_attn.compressor.indexer.kv_proj.": ".attn.indexer.compressor.wkv.",
        ".self_attn.compressor.indexer.gate_proj.": ".attn.indexer.compressor.wgate.",
        ".self_attn.compressor.indexer.position_bias": ".attn.indexer.compressor.ape",
    }
    # Rename the compressor leaf first, then relocate the compressor itself.
    compressor_rules = {
        "compressor.kv_proj.": "compressor.wkv.",
        "compressor.gate_proj.": "compressor.wgate.",
        "compressor.kv_norm.": "compressor.norm.",
        "compressor.position_bias": "compressor.ape",
        ".self_attn.compressor.": ".attn.mla_attn.compressor.",
    }
    attention_rules = {
        ".self_attn.q_a_proj.": ".attn.wq_a.",
        ".self_attn.kv_proj.": ".attn.wkv.",
        ".self_attn.q_b_proj.": ".attn.wq_b.",
        ".self_attn.o_a_proj.": ".attn.wo_a.",
        ".self_attn.o_b_proj.": ".attn.wo_b.",
        ".self_attn.q_a_norm.": ".attn.q_norm.",
        ".self_attn.kv_norm.": ".attn.kv_norm.",
        ".self_attn.sinks": ".attn.attn_sink",
    }
    # Shared-expert rules precede the generic ".mlp." catch-all.
    feed_forward_rules = {
        ".mlp.shared_experts.gate_proj.": ".ffn.shared_experts.w1.",
        ".mlp.shared_experts.up_proj.": ".ffn.shared_experts.w3.",
        ".mlp.shared_experts.down_proj.": ".ffn.shared_experts.down_proj.",
        ".mlp.": ".ffn.",
    }
    # Catch-all for any ".self_attn." left over, plus the layer norms.
    fallback_and_norm_rules = {
        ".self_attn.": ".attn.",
        "input_layernorm.": "attn_norm.",
        "post_attention_layernorm.": "ffn_norm.",
    }
    # Nested hc sub-module keys become flat parameter names.
    hc_rules = {
        ".attn_hc.fn": ".hc_attn_fn",
        ".attn_hc.base": ".hc_attn_base",
        ".attn_hc.scale": ".hc_attn_scale",
        ".ffn_hc.fn": ".hc_ffn_fn",
        ".ffn_hc.base": ".hc_ffn_base",
        ".ffn_hc.scale": ".hc_ffn_scale",
        "hc_head.hc_fn": "hc_head_fn",
        "hc_head.hc_base": "hc_head_base",
        "hc_head.hc_scale": "hc_head_scale",
    }
    # Dict unpacking keeps insertion order, so the groups apply in this order.
    substr_rules = {
        **indexer_rules,
        **compressor_rules,
        **attention_rules,
        **feed_forward_rules,
        **fallback_and_norm_rules,
        **hc_rules,
    }

    return WeightsMapper(
        orig_to_new_regex=regex_rules,
        orig_to_new_substr=substr_rules,
        orig_to_new_prefix=prefix_rules,
        orig_to_new_suffix=suffix_rules,
    )
def test_mapper():
    """Check make_nvfp4_mapper() against (checkpoint key, expected name) pairs.

    Each mismatch is printed as it is found, followed by a "passed/total"
    summary. If any case failed the process exits with status 1 through
    ``sys.exit``.
    """
    mapper = make_nvfp4_mapper()
    test_cases = [
        # Attention projections
        ("layers.0.self_attn.q_a_proj.weight", "model.layers.0.attn.wq_a.weight"),
        ("layers.0.self_attn.q_a_proj.weight_scale", "model.layers.0.attn.wq_a.weight_scale"),
        ("layers.0.self_attn.q_a_proj.weight_scale_2", "model.layers.0.attn.wq_a.weight_scale_2"),
        ("layers.0.self_attn.q_a_proj.input_scale", "model.layers.0.attn.wq_a.input_scale"),
        ("layers.0.self_attn.kv_proj.weight", "model.layers.0.attn.wkv.weight"),
        ("layers.0.self_attn.q_b_proj.weight", "model.layers.0.attn.wq_b.weight"),
        ("layers.0.self_attn.o_a_proj.weight", "model.layers.0.attn.wo_a.weight"),
        ("layers.0.self_attn.o_b_proj.weight", "model.layers.0.attn.wo_b.weight"),
        ("layers.0.self_attn.o_b_proj.weight_scale", "model.layers.0.attn.wo_b.weight_scale"),
        ("layers.0.self_attn.q_a_norm.weight", "model.layers.0.attn.q_norm.weight"),
        ("layers.0.self_attn.kv_norm.weight", "model.layers.0.attn.kv_norm.weight"),
        ("layers.0.self_attn.sinks", "model.layers.0.attn.attn_sink"),
        # Compressor (non-indexer)
        ("layers.0.self_attn.compressor.kv_proj.weight", "model.layers.0.attn.mla_attn.compressor.wkv.weight"),
        ("layers.0.self_attn.compressor.kv_proj.input_scale", "model.layers.0.attn.mla_attn.compressor.wkv.input_scale"),
        ("layers.0.self_attn.compressor.gate_proj.weight", "model.layers.0.attn.mla_attn.compressor.wgate.weight"),
        ("layers.0.self_attn.compressor.kv_norm.weight", "model.layers.0.attn.mla_attn.compressor.norm.weight"),
        ("layers.0.self_attn.compressor.position_bias", "model.layers.0.attn.mla_attn.compressor.ape"),
        # Indexer
        ("layers.2.self_attn.compressor.indexer.q_b_proj.weight", "model.layers.2.attn.indexer.wq_b.weight"),
        ("layers.2.self_attn.compressor.indexer.weights_proj.weight", "model.layers.2.attn.indexer.weights_proj.weight"),
        ("layers.2.self_attn.compressor.indexer.kv_norm.weight", "model.layers.2.attn.indexer.k_norm.weight"),
        ("layers.2.self_attn.compressor.indexer.kv_proj.weight", "model.layers.2.attn.indexer.compressor.wkv.weight"),
        ("layers.2.self_attn.compressor.indexer.gate_proj.weight", "model.layers.2.attn.indexer.compressor.wgate.weight"),
        ("layers.2.self_attn.compressor.indexer.position_bias", "model.layers.2.attn.indexer.compressor.ape"),
        # Expert weights
        ("layers.0.mlp.experts.0.gate_proj.weight", "model.layers.0.ffn.experts.0.w1.weight"),
        ("layers.0.mlp.experts.0.gate_proj.weight_scale", "model.layers.0.ffn.experts.0.w1.weight_scale"),
        ("layers.0.mlp.experts.0.gate_proj.weight_scale_2", "model.layers.0.ffn.experts.0.w1.weight_scale_2"),
        ("layers.0.mlp.experts.0.gate_proj.input_scale", "model.layers.0.ffn.experts.0.w1.input_scale"),
        ("layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.ffn.experts.0.w3.weight"),
        ("layers.0.mlp.experts.0.down_proj.weight", "model.layers.0.ffn.experts.0.w2.weight"),
        # Shared experts
        ("layers.0.mlp.shared_experts.gate_proj.weight", "model.layers.0.ffn.shared_experts.w1.weight"),
        ("layers.0.mlp.shared_experts.up_proj.weight", "model.layers.0.ffn.shared_experts.w3.weight"),
        ("layers.0.mlp.shared_experts.down_proj.weight", "model.layers.0.ffn.shared_experts.down_proj.weight"),
        # MoE gate + norm
        ("layers.0.mlp.gate.weight", "model.layers.0.ffn.norm_gate.gate.weight"),
        ("layers.0.mlp.gate.bias", "model.layers.0.ffn.norm_gate.e_score_correction_bias"),
        ("layers.0.mlp.gate.tid2eid", "model.layers.0.ffn.norm_gate.tid2eid"),
        ("layers.0.input_layernorm.weight", "model.layers.0.attn_norm.weight"),
        ("layers.0.post_attention_layernorm.weight", "model.layers.0.ffn.norm_gate.norm.weight"),
        # HC params
        ("layers.0.attn_hc.fn", "model.layers.0.hc_attn_fn"),
        ("layers.0.attn_hc.base", "model.layers.0.hc_attn_base"),
        ("layers.0.attn_hc.scale", "model.layers.0.hc_attn_scale"),
        ("layers.0.ffn_hc.fn", "model.layers.0.hc_ffn_fn"),
        ("layers.0.ffn_hc.base", "model.layers.0.hc_ffn_base"),
        ("layers.0.ffn_hc.scale", "model.layers.0.hc_ffn_scale"),
        # HC head (checkpoint has model.hc_head.hc_fn, model params are flat hc_head_fn)
        ("hc_head.hc_fn", "hc_head_fn"),
        # Same rules on the already-prefixed key shape: no prefix rule may
        # fire on these, so "model." must appear exactly once.
        ("model.hc_head.hc_fn", "model.hc_head_fn"),
        ("model.hc_head.hc_base", "model.hc_head_base"),
        ("model.hc_head.hc_scale", "model.hc_head_scale"),
        # Top-level weights (prefix + suffix rules)
        ("embed.weight", "model.embed_tokens.weight"),
        ("norm.weight", "model.norm.weight"),
        ("head.weight", "lm_head.weight"),
        # MTP (already uses ffn prefix in checkpoint)
        ("mtp.0.ffn.experts.0.w1.weight", "model.mtp.0.ffn.experts.0.w1.weight"),
        ("mtp.0.ffn_norm.weight", "model.mtp.0.ffn.norm_gate.norm.weight"),
    ]
    failed = 0
    for ckpt_key, expected in test_cases:
        result = mapper._map_name(ckpt_key)
        if result == expected:
            continue
        failed += 1
        print(f"FAIL: {ckpt_key}")
        print(f" Expected: {expected}")
        print(f" Got: {result}")
    passed = len(test_cases) - failed
    print(f"\n{passed}/{passed+failed} tests passed")
    if failed > 0:
        sys.exit(1)
# Run the mapper checks when this file is executed directly as a script.
if __name__ == "__main__":
    test_mapper()