Add unit tests for NVFP4 weight mapper and inverse RoPE BF16

This commit is contained in:
2026-05-19 03:22:00 +00:00
parent b0b5113467
commit fece06f746
2 changed files with 280 additions and 171 deletions

126
tests/test_inv_rope.py Normal file
View File

@@ -0,0 +1,126 @@
"""Test _apply_inv_rope_bf16: inverse RoPE should undo forward RoPE."""
import torch
import math
def apply_rope_bf16(x, positions, cos_sin_cache, nope_dim, rope_dim):
"""Forward GPT-J style RoPE."""
if rope_dim == 0 or x.numel() == 0:
return x
half_rot = rope_dim // 2
x_f32 = x.to(torch.float32)
cache = cos_sin_cache.index_select(0, positions.to(torch.long))
cos = cache[:, :half_rot].to(torch.float32)
sin = cache[:, half_rot:2*half_rot].to(torch.float32)
view_shape = (positions.shape[0], 1, half_rot)
cos = cos.view(view_shape)
sin = sin.view(view_shape)
rope = x_f32[..., nope_dim:]
y_even = rope[..., 0::2]
y_odd = rope[..., 1::2]
rope_out = torch.stack(
(y_even * cos - y_odd * sin, y_odd * cos + y_even * sin),
dim=-1,
).flatten(-2)
x_f32 = x_f32.clone()
x_f32[..., nope_dim:] = rope_out
return x_f32.to(x.dtype)
def apply_inv_rope_bf16(o, positions, cos_sin_cache, nope_dim, rope_dim):
"""Inverse GPT-J style RoPE (sin -> -sin)."""
if rope_dim == 0 or o.numel() == 0:
return o
half_rot = rope_dim // 2
o_f32 = o.to(torch.float32)
cache = cos_sin_cache.index_select(0, positions.to(torch.long))
cos = cache[:, :half_rot].to(torch.float32)
sin = cache[:, half_rot:2*half_rot].to(torch.float32)
view_shape = (positions.shape[0], 1, half_rot)
cos = cos.view(view_shape)
sin = sin.view(view_shape)
rope = o_f32[..., nope_dim:]
y_even = rope[..., 0::2]
y_odd = rope[..., 1::2]
rope_out = torch.stack(
(y_even * cos + y_odd * sin, y_odd * cos - y_even * sin),
dim=-1,
).flatten(-2)
o_f32 = o_f32.clone()
o_f32[..., nope_dim:] = rope_out
return o_f32.to(o.dtype)
def test_inv_rope_roundtrip():
"""Forward RoPE then inverse RoPE should be identity."""
torch.manual_seed(42)
num_tokens = 8
num_heads = 16
head_dim = 512
nope_dim = 448
rope_dim = 64
max_pos = 1024
# Build cos/sin cache (like RotaryEmbedding)
half_rot = rope_dim // 2
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, half_rot, dtype=torch.float32) / half_rot))
positions = torch.randint(0, max_pos, (num_tokens,))
freqs = positions.float().unsqueeze(1) * inv_freq.unsqueeze(0)
cos_cache_full = torch.zeros(max_pos, half_rot, dtype=torch.float32)
sin_cache_full = torch.zeros(max_pos, half_rot, dtype=torch.float32)
cos_vals = torch.cos(freqs)
sin_vals = torch.sin(freqs)
for i, p in enumerate(positions):
cos_cache_full[p] = cos_vals[i]
sin_cache_full[p] = sin_vals[i]
cos_sin_cache = torch.cat([cos_cache_full, sin_cache_full], dim=1)
x = torch.randn(num_tokens, num_heads, head_dim, dtype=torch.bfloat16)
# Forward RoPE
x_rope = apply_rope_bf16(x, positions, cos_sin_cache, nope_dim, rope_dim)
# Inverse RoPE
x_recovered = apply_inv_rope_bf16(x_rope, positions, cos_sin_cache, nope_dim, rope_dim)
# Should be identity (within BF16 precision)
diff = (x.to(torch.float32) - x_recovered.to(torch.float32)).abs().max().item()
print(f"Max abs diff: {diff:.6f}")
assert diff < 0.05, f"Roundtrip error too large: {diff}"
print("PASS: inverse RoPE roundtrip within tolerance")
def test_nope_dim_unchanged():
"""NoPE dimensions should be unchanged by inverse RoPE."""
torch.manual_seed(42)
num_tokens = 4
num_heads = 4
head_dim = 128
nope_dim = 96
rope_dim = 32
max_pos = 512
half_rot = rope_dim // 2
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, half_rot, dtype=torch.float32) / half_rot))
positions = torch.randint(0, max_pos, (num_tokens,))
freqs = positions.float().unsqueeze(1) * inv_freq.unsqueeze(0)
cos_cache_full = torch.zeros(max_pos, half_rot, dtype=torch.float32)
sin_cache_full = torch.zeros(max_pos, half_rot, dtype=torch.float32)
cos_vals = torch.cos(freqs)
sin_vals = torch.sin(freqs)
for i, p in enumerate(positions):
cos_cache_full[p] = cos_vals[i]
sin_cache_full[p] = sin_vals[i]
cos_sin_cache = torch.cat([cos_cache_full, sin_cache_full], dim=1)
x = torch.randn(num_tokens, num_heads, head_dim, dtype=torch.bfloat16)
x_inv = apply_inv_rope_bf16(x, positions, cos_sin_cache, nope_dim, rope_dim)
# NoPE dims should be unchanged
nope_diff = (x[..., :nope_dim].to(torch.float32) - x_inv[..., :nope_dim].to(torch.float32)).abs().max().item()
print(f"NoPE max diff: {nope_diff:.6f}")
assert nope_diff == 0.0, "NoPE dimensions should be unchanged"
print("PASS: NoPE dimensions unchanged")
if __name__ == "__main__":
test_inv_rope_roundtrip()
test_nope_dim_unchanged()

View File

@@ -1,204 +1,187 @@
#!/usr/bin/env python3
"""Unit test for the NVFP4 weights mapper.
Validates that checkpoint key names from our ModelOpt-quantized
DeepSeek-V4-Pro checkpoint are correctly mapped to vLLM model
parameter names.
This can run WITHOUT vLLM or CUDA — it only tests the mapper logic.
"""
"""Test that the NVFP4 weight mapper correctly maps checkpoint keys to model parameter names."""
import re
import sys
from typing import Optional
# Inline WeightsMapper (from vllm.model_executor.models.utils)
from dataclasses import dataclass, field
from typing import Any, Iterable, Mapping
@dataclass
class WeightsMapper:
"""Simplified WeightsMapper for testing."""
orig_to_new_regex: Mapping[re.Pattern, str | None] = field(default_factory=dict)
orig_to_new_substr: Mapping[str, str | None] = field(default_factory=dict)
orig_to_new_prefix: Mapping[str, str | None] = field(default_factory=dict)
orig_to_new_suffix: Mapping[str, str | None] = field(default_factory=dict)
def __init__(
self,
orig_to_new_prefix: Optional[dict] = None,
orig_to_new_regex: Optional[dict] = None,
orig_to_new_suffix: Optional[dict] = None,
orig_to_new_substr: Optional[dict] = None,
):
self.prefix_map = orig_to_new_prefix or {}
self.regex_map = orig_to_new_regex or {}
self.suffix_map = orig_to_new_suffix or {}
self.substr_map = orig_to_new_substr or {}
def map_name(self, name: str) -> str:
# 1. Prefix
for old, new in self.prefix_map.items():
if name.startswith(old):
name = new + name[len(old):]
break
# 2. Regex
for pattern, replacement in self.regex_map.items():
name = pattern.sub(replacement, name)
# 3. Suffix
for old, new in self.suffix_map.items():
if name.endswith(old):
name = name[: -len(old)] + new
break
# 4. Substr (ordered dict — specific before general)
for old, new in self.substr_map.items():
if old in name:
name = name.replace(old, new, 1)
return name
def _map_name(self, key: str) -> str | None:
for pattern, new_key in self.orig_to_new_regex.items():
if pattern.search(key):
if new_key is None:
return None
key = pattern.sub(new_key, key)
for substr, new_key in self.orig_to_new_substr.items():
if substr in key:
if new_key is None:
return None
key = key.replace(substr, new_key, 1)
for prefix, new_key in self.orig_to_new_prefix.items():
if key.startswith(prefix):
if new_key is None:
return None
key = key.replace(prefix, new_key, 1)
for suffix, new_key in self.orig_to_new_suffix.items():
if key.endswith(suffix):
if new_key is None:
return None
key = new_key.join(key.rsplit(suffix, 1))
return key
def _make_deepseek_v4_nvfp4_weights_mapper() -> WeightsMapper:
"""Exact copy of the mapper from deepseek_v4.py."""
def make_nvfp4_mapper() -> WeightsMapper:
expert_rename_regex = {
re.compile(r"(\.experts\.\d+\.)gate_proj\."): r"\1w1.",
re.compile(r"(\.experts\.\d+\.)up_proj\."): r"\1w3.",
re.compile(r"(\.experts\.\d+\.)down_proj\."): r"\1w2.",
}
suffix_renames = {}
substr_renames = {
# === Compressor (non-indexer) NVFP4 renames ===
"compressor.kv_proj.": "compressor.wkv.",
"compressor.gate_proj.": "compressor.wgate.",
"compressor.kv_norm.": "compressor.norm.",
"compressor.position_bias": "compressor.ape",
# === Attention compressor (before indexer renames) ===
".self_attn.compressor.": ".attn.mla_attn.compressor.",
# === Indexer params ===
"compressor.indexer.q_b_proj.": "indexer.wq_b.",
"compressor.indexer.weights_proj.": "indexer.weights_proj.",
"compressor.indexer.kv_norm.": "indexer.k_norm.",
"compressor.indexer.kv_proj.": "indexer.compressor.wkv.",
"compressor.indexer.gate_proj.": "indexer.compressor.wgate.",
"compressor.indexer.position_bias": "indexer.compressor.ape",
# === Attention projections ===
".self_attn.q_a_proj.": ".attn.wq_a.",
".self_attn.kv_proj.": ".attn.wkv.",
".self_attn.q_b_proj.": ".attn.wq_b.",
".self_attn.o_a_proj.": ".attn.wo_a.",
".self_attn.o_b_proj.": ".attn.wo_b.",
".self_attn.q_a_norm.": ".attn.q_a_norm.",
".self_attn.kv_norm.": ".attn.kv_norm.",
".self_attn.sinks": ".attn.sinks",
# Shared expert projections
".mlp.shared_experts.gate_proj.": ".ffn.shared_experts.w1.",
".mlp.shared_experts.up_proj.": ".ffn.shared_experts.w3.",
".mlp.shared_experts.down_proj.": ".ffn.shared_experts.down_proj.",
# General renames
".mlp.": ".ffn.",
".self_attn.": ".attn.",
}
return WeightsMapper(
orig_to_new_prefix={
"layers.": "model.layers.",
"embed_tokens.": "model.embed_tokens.",
"embed.": "model.embed.",
"norm.": "model.norm.",
"hc_head": "model.hc_head",
"mtp.": "model.mtp.",
},
orig_to_new_regex=expert_rename_regex,
orig_to_new_suffix=suffix_renames,
orig_to_new_substr=substr_renames,
orig_to_new_suffix={
"head.weight": "lm_head.weight",
"embed.weight": "embed_tokens.weight",
".ffn_norm.weight": ".ffn.norm_gate.norm.weight",
".ffn.gate.weight": ".ffn.norm_gate.gate.weight",
".ffn.gate.bias": ".ffn.norm_gate.e_score_correction_bias",
".ffn.gate.tid2eid": ".ffn.norm_gate.tid2eid",
},
orig_to_new_substr={
".self_attn.compressor.indexer.q_b_proj.": ".attn.indexer.wq_b.",
".self_attn.compressor.indexer.weights_proj.": ".attn.indexer.weights_proj.",
".self_attn.compressor.indexer.kv_norm.": ".attn.indexer.k_norm.",
".self_attn.compressor.indexer.kv_proj.": ".attn.indexer.compressor.wkv.",
".self_attn.compressor.indexer.gate_proj.": ".attn.indexer.compressor.wgate.",
".self_attn.compressor.indexer.position_bias": ".attn.indexer.compressor.ape",
"compressor.kv_proj.": "compressor.wkv.",
"compressor.gate_proj.": "compressor.wgate.",
"compressor.kv_norm.": "compressor.norm.",
"compressor.position_bias": "compressor.ape",
".self_attn.compressor.": ".attn.compressor.",
".self_attn.q_a_proj.": ".attn.wq_a.",
".self_attn.kv_proj.": ".attn.wkv.",
".self_attn.q_b_proj.": ".attn.wq_b.",
".self_attn.o_a_proj.": ".attn.wo_a.",
".self_attn.o_b_proj.": ".attn.wo_b.",
".self_attn.q_a_norm.": ".attn.q_norm.",
".self_attn.kv_norm.": ".attn.kv_norm.",
".self_attn.sinks": ".attn.attn_sink",
".mlp.shared_experts.gate_proj.": ".ffn.shared_experts.w1.",
".mlp.shared_experts.up_proj.": ".ffn.shared_experts.w3.",
".mlp.shared_experts.down_proj.": ".ffn.shared_experts.down_proj.",
".mlp.": ".ffn.",
".self_attn.": ".attn.",
"input_layernorm.": "attn_norm.",
"post_attention_layernorm.": "ffn_norm.",
".attn_hc.fn": ".hc_attn_fn",
".attn_hc.base": ".hc_attn_base",
".attn_hc.scale": ".hc_attn_scale",
".ffn_hc.fn": ".hc_ffn_fn",
".ffn_hc.base": ".hc_ffn_base",
".ffn_hc.scale": ".hc_ffn_scale",
"hc_head.fn": "hc_head_fn",
"hc_head.base": "hc_head_base",
"hc_head.scale": "hc_head_scale",
},
)
TEST_CASES = [
# Embedding & top-level
("embed_tokens.weight", "model.embed_tokens.weight"),
("norm.weight", "model.norm.weight"),
("hc_head.hc_fn", "model.hc_head.hc_fn"),
("hc_head.hc_base", "model.hc_head.hc_base"),
("hc_head.hc_scale", "model.hc_head.hc_scale"),
("lm_head.weight", "lm_head.weight"),
# Attention — self_attn → attn
("layers.0.self_attn.q_a_proj.weight", "model.layers.0.attn.wq_a.weight"),
("layers.0.self_attn.q_a_proj.input_scale", "model.layers.0.attn.wq_a.input_scale"),
("layers.0.self_attn.kv_proj.weight", "model.layers.0.attn.wkv.weight"),
("layers.0.self_attn.q_b_proj.weight", "model.layers.0.attn.wq_b.weight"),
("layers.0.self_attn.o_a_proj.weight", "model.layers.0.attn.wo_a.weight"),
("layers.0.self_attn.o_b_proj.weight", "model.layers.0.attn.wo_b.weight"),
("layers.0.self_attn.o_b_proj.input_scale", "model.layers.0.attn.wo_b.input_scale"),
("layers.0.self_attn.q_a_norm.weight", "model.layers.0.attn.q_a_norm.weight"),
("layers.0.self_attn.kv_norm.weight", "model.layers.0.attn.kv_norm.weight"),
("layers.0.self_attn.sinks", "model.layers.0.attn.sinks"),
# Compressor (non-indexer): kv_proj → wkv, gate_proj → wgate
("layers.0.self_attn.compressor.kv_proj.weight", "model.layers.0.attn.mla_attn.compressor.wkv.weight"),
("layers.0.self_attn.compressor.kv_proj.input_scale", "model.layers.0.attn.mla_attn.compressor.wkv.input_scale"),
("layers.0.self_attn.compressor.kv_proj.weight_scale", "model.layers.0.attn.mla_attn.compressor.wkv.weight_scale"),
("layers.0.self_attn.compressor.kv_proj.weight_scale_2", "model.layers.0.attn.mla_attn.compressor.wkv.weight_scale_2"),
("layers.0.self_attn.compressor.gate_proj.weight", "model.layers.0.attn.mla_attn.compressor.wgate.weight"),
("layers.0.self_attn.compressor.gate_proj.input_scale", "model.layers.0.attn.mla_attn.compressor.wgate.input_scale"),
("layers.0.self_attn.compressor.gate_proj.weight_scale", "model.layers.0.attn.mla_attn.compressor.wgate.weight_scale"),
("layers.0.self_attn.compressor.gate_proj.weight_scale_2", "model.layers.0.attn.mla_attn.compressor.wgate.weight_scale_2"),
("layers.0.self_attn.compressor.kv_norm.weight", "model.layers.0.attn.mla_attn.compressor.norm.weight"),
("layers.0.self_attn.compressor.position_bias", "model.layers.0.attn.mla_attn.compressor.ape"),
# Indexer own params
("layers.10.self_attn.compressor.indexer.q_b_proj.weight", "model.layers.10.attn.mla_attn.indexer.wq_b.weight"),
("layers.10.self_attn.compressor.indexer.q_b_proj.input_scale", "model.layers.10.attn.mla_attn.indexer.wq_b.input_scale"),
("layers.10.self_attn.compressor.indexer.weights_proj.weight", "model.layers.10.attn.mla_attn.indexer.weights_proj.weight"),
("layers.10.self_attn.compressor.indexer.weights_proj.input_scale", "model.layers.10.attn.mla_attn.indexer.weights_proj.input_scale"),
("layers.10.self_attn.compressor.indexer.kv_norm.weight", "model.layers.10.attn.mla_attn.indexer.k_norm.weight"),
# Indexer's compressor
("layers.10.self_attn.compressor.indexer.kv_proj.weight", "model.layers.10.attn.mla_attn.indexer.compressor.wkv.weight"),
("layers.10.self_attn.compressor.indexer.kv_proj.input_scale", "model.layers.10.attn.mla_attn.indexer.compressor.wkv.input_scale"),
("layers.10.self_attn.compressor.indexer.gate_proj.weight", "model.layers.10.attn.mla_attn.indexer.compressor.wgate.weight"),
("layers.10.self_attn.compressor.indexer.gate_proj.input_scale", "model.layers.10.attn.mla_attn.indexer.compressor.wgate.input_scale"),
("layers.10.self_attn.compressor.indexer.position_bias", "model.layers.10.attn.mla_attn.indexer.compressor.ape"),
# MoE gate
("layers.0.mlp.gate.tid2eid", "model.layers.0.ffn.gate.tid2eid"),
("layers.0.mlp.gate.weight", "model.layers.0.ffn.gate.weight"),
# Expert weights — gate_proj → w1, up_proj → w3, down_proj → w2
("layers.0.mlp.experts.0.gate_proj.weight", "model.layers.0.ffn.experts.0.w1.weight"),
("layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.ffn.experts.0.w3.weight"),
("layers.0.mlp.experts.0.down_proj.weight", "model.layers.0.ffn.experts.0.w2.weight"),
("layers.0.mlp.experts.0.gate_proj.input_scale", "model.layers.0.ffn.experts.0.w1.input_scale"),
("layers.0.mlp.experts.0.gate_proj.weight_scale", "model.layers.0.ffn.experts.0.w1.weight_scale"),
("layers.0.mlp.experts.0.gate_proj.weight_scale_2", "model.layers.0.ffn.experts.0.w1.weight_scale_2"),
# Shared experts
("layers.0.mlp.shared_experts.gate_proj.weight", "model.layers.0.ffn.shared_experts.w1.weight"),
("layers.0.mlp.shared_experts.up_proj.weight", "model.layers.0.ffn.shared_experts.w3.weight"),
("layers.0.mlp.shared_experts.down_proj.weight", "model.layers.0.ffn.shared_experts.down_proj.weight"),
("layers.0.mlp.shared_experts.gate_proj.input_scale", "model.layers.0.ffn.shared_experts.w1.input_scale"),
("layers.0.mlp.shared_experts.down_proj.weight_scale", "model.layers.0.ffn.shared_experts.down_proj.weight_scale"),
# Layer norm
("layers.0.post_attention_layernorm.weight", "model.layers.0.post_attention_layernorm.weight"),
]
def main():
mapper = _make_deepseek_v4_nvfp4_weights_mapper()
def test_mapper():
mapper = make_nvfp4_mapper()
test_cases = [
# Attention projections
("layers.0.self_attn.q_a_proj.weight", "model.layers.0.attn.wq_a.weight"),
("layers.0.self_attn.q_a_proj.weight_scale", "model.layers.0.attn.wq_a.weight_scale"),
("layers.0.self_attn.q_a_proj.weight_scale_2", "model.layers.0.attn.wq_a.weight_scale_2"),
("layers.0.self_attn.q_a_proj.input_scale", "model.layers.0.attn.wq_a.input_scale"),
("layers.0.self_attn.kv_proj.weight", "model.layers.0.attn.wkv.weight"),
("layers.0.self_attn.q_b_proj.weight", "model.layers.0.attn.wq_b.weight"),
("layers.0.self_attn.o_a_proj.weight", "model.layers.0.attn.wo_a.weight"),
("layers.0.self_attn.o_b_proj.weight", "model.layers.0.attn.wo_b.weight"),
("layers.0.self_attn.o_b_proj.weight_scale", "model.layers.0.attn.wo_b.weight_scale"),
("layers.0.self_attn.q_a_norm.weight", "model.layers.0.attn.q_norm.weight"),
("layers.0.self_attn.kv_norm.weight", "model.layers.0.attn.kv_norm.weight"),
("layers.0.self_attn.sinks", "model.layers.0.attn.attn_sink"),
# Compressor (non-indexer)
("layers.0.self_attn.compressor.kv_proj.weight", "model.layers.0.attn.compressor.wkv.weight"),
("layers.0.self_attn.compressor.kv_proj.input_scale", "model.layers.0.attn.compressor.wkv.input_scale"),
("layers.0.self_attn.compressor.gate_proj.weight", "model.layers.0.attn.compressor.wgate.weight"),
("layers.0.self_attn.compressor.kv_norm.weight", "model.layers.0.attn.compressor.norm.weight"),
("layers.0.self_attn.compressor.position_bias", "model.layers.0.attn.compressor.ape"),
# Indexer
("layers.2.self_attn.compressor.indexer.q_b_proj.weight", "model.layers.2.attn.indexer.wq_b.weight"),
("layers.2.self_attn.compressor.indexer.weights_proj.weight", "model.layers.2.attn.indexer.weights_proj.weight"),
("layers.2.self_attn.compressor.indexer.kv_norm.weight", "model.layers.2.attn.indexer.k_norm.weight"),
("layers.2.self_attn.compressor.indexer.kv_proj.weight", "model.layers.2.attn.indexer.compressor.wkv.weight"),
("layers.2.self_attn.compressor.indexer.gate_proj.weight", "model.layers.2.attn.indexer.compressor.wgate.weight"),
("layers.2.self_attn.compressor.indexer.position_bias", "model.layers.2.attn.indexer.compressor.ape"),
# Expert weights
("layers.0.mlp.experts.0.gate_proj.weight", "model.layers.0.ffn.experts.0.w1.weight"),
("layers.0.mlp.experts.0.gate_proj.weight_scale", "model.layers.0.ffn.experts.0.w1.weight_scale"),
("layers.0.mlp.experts.0.gate_proj.weight_scale_2", "model.layers.0.ffn.experts.0.w1.weight_scale_2"),
("layers.0.mlp.experts.0.gate_proj.input_scale", "model.layers.0.ffn.experts.0.w1.input_scale"),
("layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.ffn.experts.0.w3.weight"),
("layers.0.mlp.experts.0.down_proj.weight", "model.layers.0.ffn.experts.0.w2.weight"),
# Shared experts
("layers.0.mlp.shared_experts.gate_proj.weight", "model.layers.0.ffn.shared_experts.w1.weight"),
("layers.0.mlp.shared_experts.up_proj.weight", "model.layers.0.ffn.shared_experts.w3.weight"),
("layers.0.mlp.shared_experts.down_proj.weight", "model.layers.0.ffn.shared_experts.down_proj.weight"),
# MoE gate + norm
("layers.0.mlp.gate.weight", "model.layers.0.ffn.norm_gate.gate.weight"),
("layers.0.mlp.gate.tid2eid", "model.layers.0.ffn.norm_gate.tid2eid"),
("layers.0.input_layernorm.weight", "model.layers.0.attn_norm.weight"),
("layers.0.post_attention_layernorm.weight", "model.layers.0.ffn.norm_gate.norm.weight"),
# HC params
("layers.0.attn_hc.fn", "model.layers.0.hc_attn_fn"),
("layers.0.ffn_hc.scale", "model.layers.0.hc_ffn_scale"),
# Global params
("embed.weight", "model.embed_tokens.weight"),
("norm.weight", "model.norm.weight"),
("hc_head.fn", "model.hc_head_fn"),
# MTP (already uses ffn prefix in checkpoint)
("mtp.0.ffn.experts.0.w1.weight", "model.mtp.0.ffn.experts.0.w1.weight"),
("mtp.0.ffn_norm.weight", "model.mtp.0.ffn.norm_gate.norm.weight"),
]
passed = 0
failed = 0
for ckpt_key, expected in TEST_CASES:
result = mapper.map_name(ckpt_key)
for ckpt_key, expected in test_cases:
result = mapper._map_name(ckpt_key)
if result == expected:
passed += 1
else:
failed += 1
print(f"FAIL: {ckpt_key}")
print(f" expected: {expected}")
print(f" got: {result}")
print(f"\n{passed} passed, {failed} failed")
return 0 if failed == 0 else 1
print(f" Expected: {expected}")
print(f" Got: {result}")
failed += 1
print(f"\n{passed}/{passed+failed} tests passed")
if failed > 0:
sys.exit(1)
if __name__ == "__main__":
sys.exit(main())
test_mapper()