"""Test that the NVFP4 weight mapper correctly maps checkpoint keys to model parameter names.""" import re import sys # Inline WeightsMapper (from vllm.model_executor.models.utils) from dataclasses import dataclass, field from typing import Any, Iterable, Mapping @dataclass class WeightsMapper: orig_to_new_regex: Mapping[re.Pattern, str | None] = field(default_factory=dict) orig_to_new_substr: Mapping[str, str | None] = field(default_factory=dict) orig_to_new_prefix: Mapping[str, str | None] = field(default_factory=dict) orig_to_new_suffix: Mapping[str, str | None] = field(default_factory=dict) def _map_name(self, key: str) -> str | None: for pattern, new_key in self.orig_to_new_regex.items(): if pattern.search(key): if new_key is None: return None key = pattern.sub(new_key, key) for substr, new_key in self.orig_to_new_substr.items(): if substr in key: if new_key is None: return None key = key.replace(substr, new_key, 1) for prefix, new_key in self.orig_to_new_prefix.items(): if key.startswith(prefix): if new_key is None: return None key = key.replace(prefix, new_key, 1) for suffix, new_key in self.orig_to_new_suffix.items(): if key.endswith(suffix): if new_key is None: return None key = new_key.join(key.rsplit(suffix, 1)) return key def make_nvfp4_mapper() -> WeightsMapper: expert_rename_regex = { re.compile(r"(\.experts\.\d+\.)gate_proj\."): r"\1w1.", re.compile(r"(\.experts\.\d+\.)up_proj\."): r"\1w3.", re.compile(r"(\.experts\.\d+\.)down_proj\."): r"\1w2.", } return WeightsMapper( orig_to_new_prefix={ "layers.": "model.layers.", "embed.": "model.embed.", "norm.": "model.norm.", # hc_head NOT mapped — checkpoint already has model.hc_head.* # and model params are flat (hc_head_fn, not hc_head.fn) "mtp.": "model.mtp.", }, orig_to_new_regex=expert_rename_regex, orig_to_new_suffix={ "head.weight": "lm_head.weight", "embed.weight": "embed_tokens.weight", ".ffn_norm.weight": ".ffn.norm_gate.norm.weight", ".ffn.gate.weight": ".ffn.norm_gate.gate.weight", ".ffn.gate.bias": ".ffn.norm_gate.e_score_correction_bias", ".ffn.gate.tid2eid": ".ffn.norm_gate.tid2eid", }, orig_to_new_substr={ ".self_attn.compressor.indexer.q_b_proj.": ".attn.indexer.wq_b.", ".self_attn.compressor.indexer.weights_proj.": ".attn.indexer.weights_proj.", ".self_attn.compressor.indexer.kv_norm.": ".attn.indexer.k_norm.", ".self_attn.compressor.indexer.kv_proj.": ".attn.indexer.compressor.wkv.", ".self_attn.compressor.indexer.gate_proj.": ".attn.indexer.compressor.wgate.", ".self_attn.compressor.indexer.position_bias": ".attn.indexer.compressor.ape", "compressor.kv_proj.": "compressor.wkv.", "compressor.gate_proj.": "compressor.wgate.", "compressor.kv_norm.": "compressor.norm.", "compressor.position_bias": "compressor.ape", ".self_attn.compressor.": ".attn.mla_attn.compressor.", ".self_attn.q_a_proj.": ".attn.wq_a.", ".self_attn.kv_proj.": ".attn.wkv.", ".self_attn.q_b_proj.": ".attn.wq_b.", ".self_attn.o_a_proj.": ".attn.wo_a.", ".self_attn.o_b_proj.": ".attn.wo_b.", ".self_attn.q_a_norm.": ".attn.q_norm.", ".self_attn.kv_norm.": ".attn.kv_norm.", ".self_attn.sinks": ".attn.attn_sink", ".mlp.shared_experts.gate_proj.": ".ffn.shared_experts.w1.", ".mlp.shared_experts.up_proj.": ".ffn.shared_experts.w3.", ".mlp.shared_experts.down_proj.": ".ffn.shared_experts.down_proj.", ".mlp.": ".ffn.", ".self_attn.": ".attn.", "input_layernorm.": "attn_norm.", "post_attention_layernorm.": "ffn_norm.", ".attn_hc.fn": ".hc_attn_fn", ".attn_hc.base": ".hc_attn_base", ".attn_hc.scale": ".hc_attn_scale", ".ffn_hc.fn": ".hc_ffn_fn", ".ffn_hc.base": ".hc_ffn_base", ".ffn_hc.scale": ".hc_ffn_scale", "hc_head.hc_fn": "hc_head_fn", "hc_head.hc_base": "hc_head_base", "hc_head.hc_scale": "hc_head_scale", }, ) def test_mapper(): mapper = make_nvfp4_mapper() test_cases = [ # Attention projections ("layers.0.self_attn.q_a_proj.weight", "model.layers.0.attn.wq_a.weight"), ("layers.0.self_attn.q_a_proj.weight_scale", "model.layers.0.attn.wq_a.weight_scale"), ("layers.0.self_attn.q_a_proj.weight_scale_2", "model.layers.0.attn.wq_a.weight_scale_2"), ("layers.0.self_attn.q_a_proj.input_scale", "model.layers.0.attn.wq_a.input_scale"), ("layers.0.self_attn.kv_proj.weight", "model.layers.0.attn.wkv.weight"), ("layers.0.self_attn.q_b_proj.weight", "model.layers.0.attn.wq_b.weight"), ("layers.0.self_attn.o_a_proj.weight", "model.layers.0.attn.wo_a.weight"), ("layers.0.self_attn.o_b_proj.weight", "model.layers.0.attn.wo_b.weight"), ("layers.0.self_attn.o_b_proj.weight_scale", "model.layers.0.attn.wo_b.weight_scale"), ("layers.0.self_attn.q_a_norm.weight", "model.layers.0.attn.q_norm.weight"), ("layers.0.self_attn.kv_norm.weight", "model.layers.0.attn.kv_norm.weight"), ("layers.0.self_attn.sinks", "model.layers.0.attn.attn_sink"), # Compressor (non-indexer) ("layers.0.self_attn.compressor.kv_proj.weight", "model.layers.0.attn.mla_attn.compressor.wkv.weight"), ("layers.0.self_attn.compressor.kv_proj.input_scale", "model.layers.0.attn.mla_attn.compressor.wkv.input_scale"), ("layers.0.self_attn.compressor.gate_proj.weight", "model.layers.0.attn.mla_attn.compressor.wgate.weight"), ("layers.0.self_attn.compressor.kv_norm.weight", "model.layers.0.attn.mla_attn.compressor.norm.weight"), ("layers.0.self_attn.compressor.position_bias", "model.layers.0.attn.mla_attn.compressor.ape"), # Indexer ("layers.2.self_attn.compressor.indexer.q_b_proj.weight", "model.layers.2.attn.indexer.wq_b.weight"), ("layers.2.self_attn.compressor.indexer.weights_proj.weight", "model.layers.2.attn.indexer.weights_proj.weight"), ("layers.2.self_attn.compressor.indexer.kv_norm.weight", "model.layers.2.attn.indexer.k_norm.weight"), ("layers.2.self_attn.compressor.indexer.kv_proj.weight", "model.layers.2.attn.indexer.compressor.wkv.weight"), ("layers.2.self_attn.compressor.indexer.gate_proj.weight", "model.layers.2.attn.indexer.compressor.wgate.weight"), ("layers.2.self_attn.compressor.indexer.position_bias", "model.layers.2.attn.indexer.compressor.ape"), # Expert weights ("layers.0.mlp.experts.0.gate_proj.weight", "model.layers.0.ffn.experts.0.w1.weight"), ("layers.0.mlp.experts.0.gate_proj.weight_scale", "model.layers.0.ffn.experts.0.w1.weight_scale"), ("layers.0.mlp.experts.0.gate_proj.weight_scale_2", "model.layers.0.ffn.experts.0.w1.weight_scale_2"), ("layers.0.mlp.experts.0.gate_proj.input_scale", "model.layers.0.ffn.experts.0.w1.input_scale"), ("layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.ffn.experts.0.w3.weight"), ("layers.0.mlp.experts.0.down_proj.weight", "model.layers.0.ffn.experts.0.w2.weight"), # Shared experts ("layers.0.mlp.shared_experts.gate_proj.weight", "model.layers.0.ffn.shared_experts.w1.weight"), ("layers.0.mlp.shared_experts.up_proj.weight", "model.layers.0.ffn.shared_experts.w3.weight"), ("layers.0.mlp.shared_experts.down_proj.weight", "model.layers.0.ffn.shared_experts.down_proj.weight"), # MoE gate + norm ("layers.0.mlp.gate.weight", "model.layers.0.ffn.norm_gate.gate.weight"), ("layers.0.mlp.gate.tid2eid", "model.layers.0.ffn.norm_gate.tid2eid"), ("layers.0.input_layernorm.weight", "model.layers.0.attn_norm.weight"), ("layers.0.post_attention_layernorm.weight", "model.layers.0.ffn.norm_gate.norm.weight"), # HC params ("layers.0.attn_hc.fn", "model.layers.0.hc_attn_fn"), ("layers.0.ffn_hc.scale", "model.layers.0.hc_ffn_scale"), # HC head (checkpoint has model.hc_head.hc_fn, model params are flat hc_head_fn) ("hc_head.hc_fn", "hc_head_fn"), # MTP (already uses ffn prefix in checkpoint) ("mtp.0.ffn.experts.0.w1.weight", "model.mtp.0.ffn.experts.0.w1.weight"), ("mtp.0.ffn_norm.weight", "model.mtp.0.ffn.norm_gate.norm.weight"), ] passed = 0 failed = 0 for ckpt_key, expected in test_cases: result = mapper._map_name(ckpt_key) if result == expected: passed += 1 else: print(f"FAIL: {ckpt_key}") print(f" Expected: {expected}") print(f" Got: {result}") failed += 1 print(f"\n{passed}/{passed+failed} tests passed") if failed > 0: sys.exit(1) if __name__ == "__main__": test_mapper()