188 lines
9.2 KiB
Python
188 lines
9.2 KiB
Python
"""Test that the NVFP4 weight mapper correctly maps checkpoint keys to model parameter names."""
|
|
import re
|
|
import sys
|
|
|
|
# Inline WeightsMapper (from vllm.model_executor.models.utils)
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Iterable, Mapping
|
|
|
|
@dataclass
|
|
class WeightsMapper:
|
|
orig_to_new_regex: Mapping[re.Pattern, str | None] = field(default_factory=dict)
|
|
orig_to_new_substr: Mapping[str, str | None] = field(default_factory=dict)
|
|
orig_to_new_prefix: Mapping[str, str | None] = field(default_factory=dict)
|
|
orig_to_new_suffix: Mapping[str, str | None] = field(default_factory=dict)
|
|
|
|
def _map_name(self, key: str) -> str | None:
|
|
for pattern, new_key in self.orig_to_new_regex.items():
|
|
if pattern.search(key):
|
|
if new_key is None:
|
|
return None
|
|
key = pattern.sub(new_key, key)
|
|
for substr, new_key in self.orig_to_new_substr.items():
|
|
if substr in key:
|
|
if new_key is None:
|
|
return None
|
|
key = key.replace(substr, new_key, 1)
|
|
for prefix, new_key in self.orig_to_new_prefix.items():
|
|
if key.startswith(prefix):
|
|
if new_key is None:
|
|
return None
|
|
key = key.replace(prefix, new_key, 1)
|
|
for suffix, new_key in self.orig_to_new_suffix.items():
|
|
if key.endswith(suffix):
|
|
if new_key is None:
|
|
return None
|
|
key = new_key.join(key.rsplit(suffix, 1))
|
|
return key
|
|
|
|
|
|
def make_nvfp4_mapper() -> WeightsMapper:
    """Build the ``WeightsMapper`` that renames NVFP4 checkpoint keys.

    Checkpoint names (left side of each table) become model parameter names
    (right side). For example, ``layers.0.self_attn.q_a_proj.weight`` becomes
    ``model.layers.0.attn.wq_a.weight``.

    ``WeightsMapper._map_name`` applies the tables in a fixed order: regex,
    then substring, then prefix, then suffix. It walks each table in insertion
    order and feeds every rewrite into the next rule, so entry order matters:

    * The specific ``.self_attn.compressor.indexer.*`` and ``compressor.*``
      substrings must stay ahead of the generic ``.self_attn.compressor.`` and
      ``.self_attn.`` renames.
    * ``.mlp.shared_experts.*`` must stay ahead of ``.mlp.``.
    """
    # Routed experts carry a numeric index and ``shared_experts`` does not, so
    # these patterns never touch the shared-expert projections.
    routed_expert_renames = (
        ("gate_proj", "w1"),
        ("up_proj", "w3"),
        ("down_proj", "w2"),
    )
    regex_rules = {
        re.compile(rf"(\.experts\.\d+\.){ckpt_proj}\."): rf"\1{model_proj}."
        for ckpt_proj, model_proj in routed_expert_renames
    }

    # Top-level modules move under ``model.``. ``hc_head`` has no trailing
    # dot: by the time prefixes run, the substring pass has already turned
    # ``hc_head.fn`` etc. into ``hc_head_fn``, so ``hc_head.`` would not match.
    prefix_rules = {
        ckpt_prefix: "model." + ckpt_prefix
        for ckpt_prefix in ("layers.", "embed.", "norm.", "hc_head", "mtp.")
    }

    hc_params = ("fn", "base", "scale")
    substr_rules = {
        # Indexer nested in the attention compressor.
        ".self_attn.compressor.indexer.q_b_proj.": ".attn.indexer.wq_b.",
        ".self_attn.compressor.indexer.weights_proj.": ".attn.indexer.weights_proj.",
        ".self_attn.compressor.indexer.kv_norm.": ".attn.indexer.k_norm.",
        ".self_attn.compressor.indexer.kv_proj.": ".attn.indexer.compressor.wkv.",
        ".self_attn.compressor.indexer.gate_proj.": ".attn.indexer.compressor.wgate.",
        ".self_attn.compressor.indexer.position_bias": ".attn.indexer.compressor.ape",
        # Compressor sub-modules first, then the compressor module itself.
        "compressor.kv_proj.": "compressor.wkv.",
        "compressor.gate_proj.": "compressor.wgate.",
        "compressor.kv_norm.": "compressor.norm.",
        "compressor.position_bias": "compressor.ape",
        ".self_attn.compressor.": ".attn.compressor.",
        # Attention projections, norms and sinks.
        ".self_attn.q_a_proj.": ".attn.wq_a.",
        ".self_attn.kv_proj.": ".attn.wkv.",
        ".self_attn.q_b_proj.": ".attn.wq_b.",
        ".self_attn.o_a_proj.": ".attn.wo_a.",
        ".self_attn.o_b_proj.": ".attn.wo_b.",
        ".self_attn.q_a_norm.": ".attn.q_norm.",
        ".self_attn.kv_norm.": ".attn.kv_norm.",
        ".self_attn.sinks": ".attn.attn_sink",
        # Shared experts.
        # NOTE(review): down_proj keeps its name here while routed experts map
        # it to w2; confirm this matches the model's shared-expert module.
        ".mlp.shared_experts.gate_proj.": ".ffn.shared_experts.w1.",
        ".mlp.shared_experts.up_proj.": ".ffn.shared_experts.w3.",
        ".mlp.shared_experts.down_proj.": ".ffn.shared_experts.down_proj.",
        # Generic module renames.
        ".mlp.": ".ffn.",
        ".self_attn.": ".attn.",
        "input_layernorm.": "attn_norm.",
        "post_attention_layernorm.": "ffn_norm.",
        # Hyper-connection params: ".attn_hc.fn" -> ".hc_attn_fn", and so on
        # for base/scale, then the same for ".ffn_hc.*".
        **{
            f".{branch}_hc.{param}": f".hc_{branch}_{param}"
            for branch in ("attn", "ffn")
            for param in hc_params
        },
        # "hc_head.fn" -> "hc_head_fn", and so on for base/scale.
        **{f"hc_head.{param}": f"hc_head_{param}" for param in hc_params},
    }

    # Suffix rules run last, so they are written against names that the
    # earlier passes have already renamed (".ffn.", "ffn_norm").
    suffix_rules = {
        "head.weight": "lm_head.weight",
        "embed.weight": "embed_tokens.weight",
        ".ffn_norm.weight": ".ffn.norm_gate.norm.weight",
        ".ffn.gate.weight": ".ffn.norm_gate.gate.weight",
        ".ffn.gate.bias": ".ffn.norm_gate.e_score_correction_bias",
        ".ffn.gate.tid2eid": ".ffn.norm_gate.tid2eid",
    }

    return WeightsMapper(
        orig_to_new_regex=regex_rules,
        orig_to_new_substr=substr_rules,
        orig_to_new_prefix=prefix_rules,
        orig_to_new_suffix=suffix_rules,
    )
|
|
|
|
|
|
def test_mapper():
    """Check that ``make_nvfp4_mapper`` renames representative checkpoint keys.

    Every case is evaluated before the check fails, so one run reports all
    mismatches. The report (each mismatch plus a pass count) is printed, and
    the function then fails through a plain ``assert`` whose message lists
    the failing keys. It no longer calls ``sys.exit`` from inside a test
    function, which only surfaced as an opaque ``SystemExit: 1``.
    """
    mapper = make_nvfp4_mapper()

    test_cases = [
        # Attention projections
        ("layers.0.self_attn.q_a_proj.weight", "model.layers.0.attn.wq_a.weight"),
        ("layers.0.self_attn.q_a_proj.weight_scale", "model.layers.0.attn.wq_a.weight_scale"),
        ("layers.0.self_attn.q_a_proj.weight_scale_2", "model.layers.0.attn.wq_a.weight_scale_2"),
        ("layers.0.self_attn.q_a_proj.input_scale", "model.layers.0.attn.wq_a.input_scale"),
        ("layers.0.self_attn.kv_proj.weight", "model.layers.0.attn.wkv.weight"),
        ("layers.0.self_attn.q_b_proj.weight", "model.layers.0.attn.wq_b.weight"),
        ("layers.0.self_attn.o_a_proj.weight", "model.layers.0.attn.wo_a.weight"),
        ("layers.0.self_attn.o_b_proj.weight", "model.layers.0.attn.wo_b.weight"),
        ("layers.0.self_attn.o_b_proj.weight_scale", "model.layers.0.attn.wo_b.weight_scale"),
        ("layers.0.self_attn.q_a_norm.weight", "model.layers.0.attn.q_norm.weight"),
        ("layers.0.self_attn.kv_norm.weight", "model.layers.0.attn.kv_norm.weight"),
        ("layers.0.self_attn.sinks", "model.layers.0.attn.attn_sink"),

        # Compressor (non-indexer)
        ("layers.0.self_attn.compressor.kv_proj.weight", "model.layers.0.attn.compressor.wkv.weight"),
        ("layers.0.self_attn.compressor.kv_proj.input_scale", "model.layers.0.attn.compressor.wkv.input_scale"),
        ("layers.0.self_attn.compressor.gate_proj.weight", "model.layers.0.attn.compressor.wgate.weight"),
        ("layers.0.self_attn.compressor.kv_norm.weight", "model.layers.0.attn.compressor.norm.weight"),
        ("layers.0.self_attn.compressor.position_bias", "model.layers.0.attn.compressor.ape"),

        # Indexer
        ("layers.2.self_attn.compressor.indexer.q_b_proj.weight", "model.layers.2.attn.indexer.wq_b.weight"),
        ("layers.2.self_attn.compressor.indexer.weights_proj.weight", "model.layers.2.attn.indexer.weights_proj.weight"),
        ("layers.2.self_attn.compressor.indexer.kv_norm.weight", "model.layers.2.attn.indexer.k_norm.weight"),
        ("layers.2.self_attn.compressor.indexer.kv_proj.weight", "model.layers.2.attn.indexer.compressor.wkv.weight"),
        ("layers.2.self_attn.compressor.indexer.gate_proj.weight", "model.layers.2.attn.indexer.compressor.wgate.weight"),
        ("layers.2.self_attn.compressor.indexer.position_bias", "model.layers.2.attn.indexer.compressor.ape"),

        # Expert weights
        ("layers.0.mlp.experts.0.gate_proj.weight", "model.layers.0.ffn.experts.0.w1.weight"),
        ("layers.0.mlp.experts.0.gate_proj.weight_scale", "model.layers.0.ffn.experts.0.w1.weight_scale"),
        ("layers.0.mlp.experts.0.gate_proj.weight_scale_2", "model.layers.0.ffn.experts.0.w1.weight_scale_2"),
        ("layers.0.mlp.experts.0.gate_proj.input_scale", "model.layers.0.ffn.experts.0.w1.input_scale"),
        ("layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.ffn.experts.0.w3.weight"),
        ("layers.0.mlp.experts.0.down_proj.weight", "model.layers.0.ffn.experts.0.w2.weight"),
        # Multi-digit layer and expert indices
        ("layers.10.mlp.experts.255.down_proj.weight", "model.layers.10.ffn.experts.255.w2.weight"),

        # Shared experts
        ("layers.0.mlp.shared_experts.gate_proj.weight", "model.layers.0.ffn.shared_experts.w1.weight"),
        ("layers.0.mlp.shared_experts.up_proj.weight", "model.layers.0.ffn.shared_experts.w3.weight"),
        ("layers.0.mlp.shared_experts.down_proj.weight", "model.layers.0.ffn.shared_experts.down_proj.weight"),

        # MoE gate + norm
        ("layers.0.mlp.gate.weight", "model.layers.0.ffn.norm_gate.gate.weight"),
        ("layers.0.mlp.gate.bias", "model.layers.0.ffn.norm_gate.e_score_correction_bias"),
        ("layers.0.mlp.gate.tid2eid", "model.layers.0.ffn.norm_gate.tid2eid"),
        ("layers.0.input_layernorm.weight", "model.layers.0.attn_norm.weight"),
        ("layers.0.post_attention_layernorm.weight", "model.layers.0.ffn.norm_gate.norm.weight"),

        # HC params
        ("layers.0.attn_hc.fn", "model.layers.0.hc_attn_fn"),
        ("layers.0.attn_hc.base", "model.layers.0.hc_attn_base"),
        ("layers.0.attn_hc.scale", "model.layers.0.hc_attn_scale"),
        ("layers.0.ffn_hc.fn", "model.layers.0.hc_ffn_fn"),
        ("layers.0.ffn_hc.base", "model.layers.0.hc_ffn_base"),
        ("layers.0.ffn_hc.scale", "model.layers.0.hc_ffn_scale"),

        # Global params
        ("embed.weight", "model.embed_tokens.weight"),
        ("norm.weight", "model.norm.weight"),
        ("head.weight", "lm_head.weight"),
        ("hc_head.fn", "model.hc_head_fn"),
        ("hc_head.base", "model.hc_head_base"),
        ("hc_head.scale", "model.hc_head_scale"),

        # MTP (already uses ffn prefix in checkpoint)
        ("mtp.0.ffn.experts.0.w1.weight", "model.mtp.0.ffn.experts.0.w1.weight"),
        ("mtp.0.ffn_norm.weight", "model.mtp.0.ffn.norm_gate.norm.weight"),
    ]

    # Evaluate every case first so a single run reports all mismatches.
    failures = []
    for ckpt_key, expected in test_cases:
        result = mapper._map_name(ckpt_key)
        if result != expected:
            failures.append((ckpt_key, expected, result))

    for ckpt_key, expected, result in failures:
        print(f"FAIL: {ckpt_key}")
        print(f"  Expected: {expected}")
        print(f"  Got:      {result}")

    passed = len(test_cases) - len(failures)
    print(f"\n{passed}/{len(test_cases)} tests passed")

    assert not failures, (
        f"{len(failures)} of {len(test_cases)} key mappings are wrong: "
        + ", ".join(ckpt_key for ckpt_key, _, _ in failures)
    )
|
|
|
|
|
|
# Run the mapping checks when this file is executed directly; test_mapper()
# prints any mismatching keys followed by a pass count.
if __name__ == "__main__":
    test_mapper()
|