"""Test _apply_inv_rope_bf16: inverse RoPE should undo forward RoPE."""
import math

import torch


def _build_cos_sin_cache(max_pos, rope_dim, base=10000.0):
    """Build a ``(max_pos, 2 * (rope_dim // 2))`` cos/sin lookup table.

    Row ``p`` holds ``cos(p * inv_freq)`` in its first ``rope_dim // 2``
    columns and ``sin(p * inv_freq)`` in the remaining ones, which is the
    layout the two RoPE helpers below slice apart.
    """
    half_rot = rope_dim // 2
    inv_freq = 1.0 / (
        base ** (torch.arange(0, half_rot, dtype=torch.float32) / half_rot)
    )
    # One vectorised outer product instead of filling rows in a Python loop.
    angles = (
        torch.arange(max_pos, dtype=torch.float32).unsqueeze(1)
        * inv_freq.unsqueeze(0)
    )
    return torch.cat([torch.cos(angles), torch.sin(angles)], dim=1)


def _rotate_rope(x, positions, cos_sin_cache, nope_dim, rope_dim, *, inverse):
    """Rotate the trailing rotary slice of ``x`` (GPT-J interleaved pairs).

    Shared implementation for the forward and inverse helpers: the inverse
    rotation is the forward rotation with ``sin`` negated.

    Args:
        x: Input tensor; the tests here pass
            ``(num_tokens, num_heads, head_dim)``. Everything from index
            ``nope_dim`` onward on the last axis is rotated.
        positions: 1-D tensor of row indices into ``cos_sin_cache``, one per
            entry along the first axis of ``x``.
        cos_sin_cache: Table whose first ``rope_dim // 2`` columns are cos
            values and whose next ``rope_dim // 2`` columns are sin values.
        nope_dim: Number of leading last-axis elements left untouched.
        rope_dim: Width of the rotary slice; ``0`` disables the rotation.
        inverse: When true, apply the inverse rotation.

    Returns:
        ``x`` itself when ``rope_dim == 0`` or ``x`` is empty; otherwise a
        new tensor with ``x``'s dtype. The input is never modified.
    """
    if rope_dim == 0 or x.numel() == 0:
        return x

    half_rot = rope_dim // 2
    x_f32 = x.to(torch.float32)  # do the arithmetic in float32

    cache = cos_sin_cache.index_select(0, positions.to(torch.long))
    # Insert a singleton axis so cos/sin broadcast over the heads axis.
    view_shape = (positions.shape[0], 1, half_rot)
    cos = cache[:, :half_rot].to(torch.float32).view(view_shape)
    sin = cache[:, half_rot:2 * half_rot].to(torch.float32).view(view_shape)
    if inverse:
        sin = -sin

    rope = x_f32[..., nope_dim:]
    even = rope[..., 0::2]
    odd = rope[..., 1::2]
    rotated = torch.stack(
        (even * cos - odd * sin, odd * cos + even * sin),
        dim=-1,
    ).flatten(-2)  # re-interleave the (even, odd) pairs

    # Clone before writing: ``.to`` returns the input itself when it is
    # already float32, and the caller's tensor must not be mutated.
    out = x_f32.clone()
    out[..., nope_dim:] = rotated
    return out.to(x.dtype)


def apply_rope_bf16(x, positions, cos_sin_cache, nope_dim, rope_dim):
    """Forward GPT-J style RoPE."""
    return _rotate_rope(
        x, positions, cos_sin_cache, nope_dim, rope_dim, inverse=False
    )


def apply_inv_rope_bf16(o, positions, cos_sin_cache, nope_dim, rope_dim):
    """Inverse GPT-J style RoPE (sin -> -sin)."""
    return _rotate_rope(
        o, positions, cos_sin_cache, nope_dim, rope_dim, inverse=True
    )


def test_inv_rope_roundtrip():
    """Forward RoPE then inverse RoPE should be identity."""
    torch.manual_seed(42)
    num_tokens, num_heads, head_dim = 8, 16, 512
    nope_dim, rope_dim = 448, 64
    max_pos = 1024

    # Build cos/sin cache (like RotaryEmbedding)
    cos_sin_cache = _build_cos_sin_cache(max_pos, rope_dim)
    positions = torch.randint(0, max_pos, (num_tokens,))
    x = torch.randn(num_tokens, num_heads, head_dim, dtype=torch.bfloat16)

    # Forward RoPE
    x_rope = apply_rope_bf16(x, positions, cos_sin_cache, nope_dim, rope_dim)
    # Without this check the round trip would also pass if both helpers
    # were no-ops.
    assert not torch.equal(x_rope[..., nope_dim:], x[..., nope_dim:]), (
        "Forward RoPE left the rotary slice unchanged"
    )

    # Inverse RoPE
    x_recovered = apply_inv_rope_bf16(
        x_rope, positions, cos_sin_cache, nope_dim, rope_dim
    )

    # Should be identity (within BF16 precision)
    diff = (
        (x.to(torch.float32) - x_recovered.to(torch.float32)).abs().max().item()
    )
    print(f"Max abs diff: {diff:.6f}")
    assert diff < 0.05, f"Roundtrip error too large: {diff}"
    print("PASS: inverse RoPE roundtrip within tolerance")


def test_nope_dim_unchanged():
    """NoPE dimensions should be unchanged by forward and inverse RoPE."""
    torch.manual_seed(42)
    num_tokens, num_heads, head_dim = 4, 4, 128
    nope_dim, rope_dim = 96, 32
    max_pos = 512

    cos_sin_cache = _build_cos_sin_cache(max_pos, rope_dim)
    positions = torch.randint(0, max_pos, (num_tokens,))
    x = torch.randn(num_tokens, num_heads, head_dim, dtype=torch.bfloat16)

    directions = (("forward", apply_rope_bf16), ("inverse", apply_inv_rope_bf16))
    for label, rope_fn in directions:
        out = rope_fn(x, positions, cos_sin_cache, nope_dim, rope_dim)
        # NoPE dims should be unchanged
        nope_diff = (
            (
                x[..., :nope_dim].to(torch.float32)
                - out[..., :nope_dim].to(torch.float32)
            )
            .abs()
            .max()
            .item()
        )
        print(f"NoPE max diff ({label}): {nope_diff:.6f}")
        assert nope_diff == 0.0, (
            f"NoPE dimensions should be unchanged by {label} RoPE"
        )
    print("PASS: NoPE dimensions unchanged")


if __name__ == "__main__":
    test_inv_rope_roundtrip()
    test_nope_dim_unchanged()
-""" - +"""Test that the NVFP4 weight mapper correctly maps checkpoint keys to model parameter names.""" import re import sys -from typing import Optional +# Inline WeightsMapper (from vllm.model_executor.models.utils) +from dataclasses import dataclass, field +from typing import Any, Iterable, Mapping +@dataclass class WeightsMapper: - """Simplified WeightsMapper for testing.""" + orig_to_new_regex: Mapping[re.Pattern, str | None] = field(default_factory=dict) + orig_to_new_substr: Mapping[str, str | None] = field(default_factory=dict) + orig_to_new_prefix: Mapping[str, str | None] = field(default_factory=dict) + orig_to_new_suffix: Mapping[str, str | None] = field(default_factory=dict) - def __init__( - self, - orig_to_new_prefix: Optional[dict] = None, - orig_to_new_regex: Optional[dict] = None, - orig_to_new_suffix: Optional[dict] = None, - orig_to_new_substr: Optional[dict] = None, - ): - self.prefix_map = orig_to_new_prefix or {} - self.regex_map = orig_to_new_regex or {} - self.suffix_map = orig_to_new_suffix or {} - self.substr_map = orig_to_new_substr or {} - - def map_name(self, name: str) -> str: - # 1. Prefix - for old, new in self.prefix_map.items(): - if name.startswith(old): - name = new + name[len(old):] - break - - # 2. Regex - for pattern, replacement in self.regex_map.items(): - name = pattern.sub(replacement, name) - - # 3. Suffix - for old, new in self.suffix_map.items(): - if name.endswith(old): - name = name[: -len(old)] + new - break - - # 4. 
Substr (ordered dict — specific before general) - for old, new in self.substr_map.items(): - if old in name: - name = name.replace(old, new, 1) - - return name + def _map_name(self, key: str) -> str | None: + for pattern, new_key in self.orig_to_new_regex.items(): + if pattern.search(key): + if new_key is None: + return None + key = pattern.sub(new_key, key) + for substr, new_key in self.orig_to_new_substr.items(): + if substr in key: + if new_key is None: + return None + key = key.replace(substr, new_key, 1) + for prefix, new_key in self.orig_to_new_prefix.items(): + if key.startswith(prefix): + if new_key is None: + return None + key = key.replace(prefix, new_key, 1) + for suffix, new_key in self.orig_to_new_suffix.items(): + if key.endswith(suffix): + if new_key is None: + return None + key = new_key.join(key.rsplit(suffix, 1)) + return key -def _make_deepseek_v4_nvfp4_weights_mapper() -> WeightsMapper: - """Exact copy of the mapper from deepseek_v4.py.""" +def make_nvfp4_mapper() -> WeightsMapper: expert_rename_regex = { re.compile(r"(\.experts\.\d+\.)gate_proj\."): r"\1w1.", re.compile(r"(\.experts\.\d+\.)up_proj\."): r"\1w3.", re.compile(r"(\.experts\.\d+\.)down_proj\."): r"\1w2.", } - - suffix_renames = {} - - substr_renames = { - # === Compressor (non-indexer) NVFP4 renames === - "compressor.kv_proj.": "compressor.wkv.", - "compressor.gate_proj.": "compressor.wgate.", - "compressor.kv_norm.": "compressor.norm.", - "compressor.position_bias": "compressor.ape", - # === Attention compressor (before indexer renames) === - ".self_attn.compressor.": ".attn.mla_attn.compressor.", - # === Indexer params === - "compressor.indexer.q_b_proj.": "indexer.wq_b.", - "compressor.indexer.weights_proj.": "indexer.weights_proj.", - "compressor.indexer.kv_norm.": "indexer.k_norm.", - "compressor.indexer.kv_proj.": "indexer.compressor.wkv.", - "compressor.indexer.gate_proj.": "indexer.compressor.wgate.", - "compressor.indexer.position_bias": "indexer.compressor.ape", - # === 
Attention projections === - ".self_attn.q_a_proj.": ".attn.wq_a.", - ".self_attn.kv_proj.": ".attn.wkv.", - ".self_attn.q_b_proj.": ".attn.wq_b.", - ".self_attn.o_a_proj.": ".attn.wo_a.", - ".self_attn.o_b_proj.": ".attn.wo_b.", - ".self_attn.q_a_norm.": ".attn.q_a_norm.", - ".self_attn.kv_norm.": ".attn.kv_norm.", - ".self_attn.sinks": ".attn.sinks", - # Shared expert projections - ".mlp.shared_experts.gate_proj.": ".ffn.shared_experts.w1.", - ".mlp.shared_experts.up_proj.": ".ffn.shared_experts.w3.", - ".mlp.shared_experts.down_proj.": ".ffn.shared_experts.down_proj.", - # General renames - ".mlp.": ".ffn.", - ".self_attn.": ".attn.", - } - return WeightsMapper( orig_to_new_prefix={ "layers.": "model.layers.", - "embed_tokens.": "model.embed_tokens.", + "embed.": "model.embed.", "norm.": "model.norm.", "hc_head": "model.hc_head", "mtp.": "model.mtp.", }, orig_to_new_regex=expert_rename_regex, - orig_to_new_suffix=suffix_renames, - orig_to_new_substr=substr_renames, + orig_to_new_suffix={ + "head.weight": "lm_head.weight", + "embed.weight": "embed_tokens.weight", + ".ffn_norm.weight": ".ffn.norm_gate.norm.weight", + ".ffn.gate.weight": ".ffn.norm_gate.gate.weight", + ".ffn.gate.bias": ".ffn.norm_gate.e_score_correction_bias", + ".ffn.gate.tid2eid": ".ffn.norm_gate.tid2eid", + }, + orig_to_new_substr={ + ".self_attn.compressor.indexer.q_b_proj.": ".attn.indexer.wq_b.", + ".self_attn.compressor.indexer.weights_proj.": ".attn.indexer.weights_proj.", + ".self_attn.compressor.indexer.kv_norm.": ".attn.indexer.k_norm.", + ".self_attn.compressor.indexer.kv_proj.": ".attn.indexer.compressor.wkv.", + ".self_attn.compressor.indexer.gate_proj.": ".attn.indexer.compressor.wgate.", + ".self_attn.compressor.indexer.position_bias": ".attn.indexer.compressor.ape", + "compressor.kv_proj.": "compressor.wkv.", + "compressor.gate_proj.": "compressor.wgate.", + "compressor.kv_norm.": "compressor.norm.", + "compressor.position_bias": "compressor.ape", + ".self_attn.compressor.": 
".attn.compressor.", + ".self_attn.q_a_proj.": ".attn.wq_a.", + ".self_attn.kv_proj.": ".attn.wkv.", + ".self_attn.q_b_proj.": ".attn.wq_b.", + ".self_attn.o_a_proj.": ".attn.wo_a.", + ".self_attn.o_b_proj.": ".attn.wo_b.", + ".self_attn.q_a_norm.": ".attn.q_norm.", + ".self_attn.kv_norm.": ".attn.kv_norm.", + ".self_attn.sinks": ".attn.attn_sink", + ".mlp.shared_experts.gate_proj.": ".ffn.shared_experts.w1.", + ".mlp.shared_experts.up_proj.": ".ffn.shared_experts.w3.", + ".mlp.shared_experts.down_proj.": ".ffn.shared_experts.down_proj.", + ".mlp.": ".ffn.", + ".self_attn.": ".attn.", + "input_layernorm.": "attn_norm.", + "post_attention_layernorm.": "ffn_norm.", + ".attn_hc.fn": ".hc_attn_fn", + ".attn_hc.base": ".hc_attn_base", + ".attn_hc.scale": ".hc_attn_scale", + ".ffn_hc.fn": ".hc_ffn_fn", + ".ffn_hc.base": ".hc_ffn_base", + ".ffn_hc.scale": ".hc_ffn_scale", + "hc_head.fn": "hc_head_fn", + "hc_head.base": "hc_head_base", + "hc_head.scale": "hc_head_scale", + }, ) -TEST_CASES = [ - # Embedding & top-level - ("embed_tokens.weight", "model.embed_tokens.weight"), - ("norm.weight", "model.norm.weight"), - ("hc_head.hc_fn", "model.hc_head.hc_fn"), - ("hc_head.hc_base", "model.hc_head.hc_base"), - ("hc_head.hc_scale", "model.hc_head.hc_scale"), - ("lm_head.weight", "lm_head.weight"), - - # Attention — self_attn → attn - ("layers.0.self_attn.q_a_proj.weight", "model.layers.0.attn.wq_a.weight"), - ("layers.0.self_attn.q_a_proj.input_scale", "model.layers.0.attn.wq_a.input_scale"), - ("layers.0.self_attn.kv_proj.weight", "model.layers.0.attn.wkv.weight"), - ("layers.0.self_attn.q_b_proj.weight", "model.layers.0.attn.wq_b.weight"), - ("layers.0.self_attn.o_a_proj.weight", "model.layers.0.attn.wo_a.weight"), - ("layers.0.self_attn.o_b_proj.weight", "model.layers.0.attn.wo_b.weight"), - ("layers.0.self_attn.o_b_proj.input_scale", "model.layers.0.attn.wo_b.input_scale"), - ("layers.0.self_attn.q_a_norm.weight", "model.layers.0.attn.q_a_norm.weight"), - 
("layers.0.self_attn.kv_norm.weight", "model.layers.0.attn.kv_norm.weight"), - ("layers.0.self_attn.sinks", "model.layers.0.attn.sinks"), - - # Compressor (non-indexer): kv_proj → wkv, gate_proj → wgate - ("layers.0.self_attn.compressor.kv_proj.weight", "model.layers.0.attn.mla_attn.compressor.wkv.weight"), - ("layers.0.self_attn.compressor.kv_proj.input_scale", "model.layers.0.attn.mla_attn.compressor.wkv.input_scale"), - ("layers.0.self_attn.compressor.kv_proj.weight_scale", "model.layers.0.attn.mla_attn.compressor.wkv.weight_scale"), - ("layers.0.self_attn.compressor.kv_proj.weight_scale_2", "model.layers.0.attn.mla_attn.compressor.wkv.weight_scale_2"), - ("layers.0.self_attn.compressor.gate_proj.weight", "model.layers.0.attn.mla_attn.compressor.wgate.weight"), - ("layers.0.self_attn.compressor.gate_proj.input_scale", "model.layers.0.attn.mla_attn.compressor.wgate.input_scale"), - ("layers.0.self_attn.compressor.gate_proj.weight_scale", "model.layers.0.attn.mla_attn.compressor.wgate.weight_scale"), - ("layers.0.self_attn.compressor.gate_proj.weight_scale_2", "model.layers.0.attn.mla_attn.compressor.wgate.weight_scale_2"), - ("layers.0.self_attn.compressor.kv_norm.weight", "model.layers.0.attn.mla_attn.compressor.norm.weight"), - ("layers.0.self_attn.compressor.position_bias", "model.layers.0.attn.mla_attn.compressor.ape"), - - # Indexer own params - ("layers.10.self_attn.compressor.indexer.q_b_proj.weight", "model.layers.10.attn.mla_attn.indexer.wq_b.weight"), - ("layers.10.self_attn.compressor.indexer.q_b_proj.input_scale", "model.layers.10.attn.mla_attn.indexer.wq_b.input_scale"), - ("layers.10.self_attn.compressor.indexer.weights_proj.weight", "model.layers.10.attn.mla_attn.indexer.weights_proj.weight"), - ("layers.10.self_attn.compressor.indexer.weights_proj.input_scale", "model.layers.10.attn.mla_attn.indexer.weights_proj.input_scale"), - ("layers.10.self_attn.compressor.indexer.kv_norm.weight", "model.layers.10.attn.mla_attn.indexer.k_norm.weight"), - - # 
Indexer's compressor - ("layers.10.self_attn.compressor.indexer.kv_proj.weight", "model.layers.10.attn.mla_attn.indexer.compressor.wkv.weight"), - ("layers.10.self_attn.compressor.indexer.kv_proj.input_scale", "model.layers.10.attn.mla_attn.indexer.compressor.wkv.input_scale"), - ("layers.10.self_attn.compressor.indexer.gate_proj.weight", "model.layers.10.attn.mla_attn.indexer.compressor.wgate.weight"), - ("layers.10.self_attn.compressor.indexer.gate_proj.input_scale", "model.layers.10.attn.mla_attn.indexer.compressor.wgate.input_scale"), - ("layers.10.self_attn.compressor.indexer.position_bias", "model.layers.10.attn.mla_attn.indexer.compressor.ape"), - - # MoE gate - ("layers.0.mlp.gate.tid2eid", "model.layers.0.ffn.gate.tid2eid"), - ("layers.0.mlp.gate.weight", "model.layers.0.ffn.gate.weight"), - - # Expert weights — gate_proj → w1, up_proj → w3, down_proj → w2 - ("layers.0.mlp.experts.0.gate_proj.weight", "model.layers.0.ffn.experts.0.w1.weight"), - ("layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.ffn.experts.0.w3.weight"), - ("layers.0.mlp.experts.0.down_proj.weight", "model.layers.0.ffn.experts.0.w2.weight"), - ("layers.0.mlp.experts.0.gate_proj.input_scale", "model.layers.0.ffn.experts.0.w1.input_scale"), - ("layers.0.mlp.experts.0.gate_proj.weight_scale", "model.layers.0.ffn.experts.0.w1.weight_scale"), - ("layers.0.mlp.experts.0.gate_proj.weight_scale_2", "model.layers.0.ffn.experts.0.w1.weight_scale_2"), - - # Shared experts - ("layers.0.mlp.shared_experts.gate_proj.weight", "model.layers.0.ffn.shared_experts.w1.weight"), - ("layers.0.mlp.shared_experts.up_proj.weight", "model.layers.0.ffn.shared_experts.w3.weight"), - ("layers.0.mlp.shared_experts.down_proj.weight", "model.layers.0.ffn.shared_experts.down_proj.weight"), - ("layers.0.mlp.shared_experts.gate_proj.input_scale", "model.layers.0.ffn.shared_experts.w1.input_scale"), - ("layers.0.mlp.shared_experts.down_proj.weight_scale", "model.layers.0.ffn.shared_experts.down_proj.weight_scale"), - 
- # Layer norm - ("layers.0.post_attention_layernorm.weight", "model.layers.0.post_attention_layernorm.weight"), -] - - -def main(): - mapper = _make_deepseek_v4_nvfp4_weights_mapper() +def test_mapper(): + mapper = make_nvfp4_mapper() + + test_cases = [ + # Attention projections + ("layers.0.self_attn.q_a_proj.weight", "model.layers.0.attn.wq_a.weight"), + ("layers.0.self_attn.q_a_proj.weight_scale", "model.layers.0.attn.wq_a.weight_scale"), + ("layers.0.self_attn.q_a_proj.weight_scale_2", "model.layers.0.attn.wq_a.weight_scale_2"), + ("layers.0.self_attn.q_a_proj.input_scale", "model.layers.0.attn.wq_a.input_scale"), + ("layers.0.self_attn.kv_proj.weight", "model.layers.0.attn.wkv.weight"), + ("layers.0.self_attn.q_b_proj.weight", "model.layers.0.attn.wq_b.weight"), + ("layers.0.self_attn.o_a_proj.weight", "model.layers.0.attn.wo_a.weight"), + ("layers.0.self_attn.o_b_proj.weight", "model.layers.0.attn.wo_b.weight"), + ("layers.0.self_attn.o_b_proj.weight_scale", "model.layers.0.attn.wo_b.weight_scale"), + ("layers.0.self_attn.q_a_norm.weight", "model.layers.0.attn.q_norm.weight"), + ("layers.0.self_attn.kv_norm.weight", "model.layers.0.attn.kv_norm.weight"), + ("layers.0.self_attn.sinks", "model.layers.0.attn.attn_sink"), + + # Compressor (non-indexer) + ("layers.0.self_attn.compressor.kv_proj.weight", "model.layers.0.attn.compressor.wkv.weight"), + ("layers.0.self_attn.compressor.kv_proj.input_scale", "model.layers.0.attn.compressor.wkv.input_scale"), + ("layers.0.self_attn.compressor.gate_proj.weight", "model.layers.0.attn.compressor.wgate.weight"), + ("layers.0.self_attn.compressor.kv_norm.weight", "model.layers.0.attn.compressor.norm.weight"), + ("layers.0.self_attn.compressor.position_bias", "model.layers.0.attn.compressor.ape"), + + # Indexer + ("layers.2.self_attn.compressor.indexer.q_b_proj.weight", "model.layers.2.attn.indexer.wq_b.weight"), + ("layers.2.self_attn.compressor.indexer.weights_proj.weight", 
"model.layers.2.attn.indexer.weights_proj.weight"), + ("layers.2.self_attn.compressor.indexer.kv_norm.weight", "model.layers.2.attn.indexer.k_norm.weight"), + ("layers.2.self_attn.compressor.indexer.kv_proj.weight", "model.layers.2.attn.indexer.compressor.wkv.weight"), + ("layers.2.self_attn.compressor.indexer.gate_proj.weight", "model.layers.2.attn.indexer.compressor.wgate.weight"), + ("layers.2.self_attn.compressor.indexer.position_bias", "model.layers.2.attn.indexer.compressor.ape"), + + # Expert weights + ("layers.0.mlp.experts.0.gate_proj.weight", "model.layers.0.ffn.experts.0.w1.weight"), + ("layers.0.mlp.experts.0.gate_proj.weight_scale", "model.layers.0.ffn.experts.0.w1.weight_scale"), + ("layers.0.mlp.experts.0.gate_proj.weight_scale_2", "model.layers.0.ffn.experts.0.w1.weight_scale_2"), + ("layers.0.mlp.experts.0.gate_proj.input_scale", "model.layers.0.ffn.experts.0.w1.input_scale"), + ("layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.ffn.experts.0.w3.weight"), + ("layers.0.mlp.experts.0.down_proj.weight", "model.layers.0.ffn.experts.0.w2.weight"), + + # Shared experts + ("layers.0.mlp.shared_experts.gate_proj.weight", "model.layers.0.ffn.shared_experts.w1.weight"), + ("layers.0.mlp.shared_experts.up_proj.weight", "model.layers.0.ffn.shared_experts.w3.weight"), + ("layers.0.mlp.shared_experts.down_proj.weight", "model.layers.0.ffn.shared_experts.down_proj.weight"), + + # MoE gate + norm + ("layers.0.mlp.gate.weight", "model.layers.0.ffn.norm_gate.gate.weight"), + ("layers.0.mlp.gate.tid2eid", "model.layers.0.ffn.norm_gate.tid2eid"), + ("layers.0.input_layernorm.weight", "model.layers.0.attn_norm.weight"), + ("layers.0.post_attention_layernorm.weight", "model.layers.0.ffn.norm_gate.norm.weight"), + + # HC params + ("layers.0.attn_hc.fn", "model.layers.0.hc_attn_fn"), + ("layers.0.ffn_hc.scale", "model.layers.0.hc_ffn_scale"), + + # Global params + ("embed.weight", "model.embed_tokens.weight"), + ("norm.weight", "model.norm.weight"), + 
("hc_head.fn", "model.hc_head_fn"), + + # MTP (already uses ffn prefix in checkpoint) + ("mtp.0.ffn.experts.0.w1.weight", "model.mtp.0.ffn.experts.0.w1.weight"), + ("mtp.0.ffn_norm.weight", "model.mtp.0.ffn.norm_gate.norm.weight"), + ] + passed = 0 failed = 0 - - for ckpt_key, expected in TEST_CASES: - result = mapper.map_name(ckpt_key) + for ckpt_key, expected in test_cases: + result = mapper._map_name(ckpt_key) if result == expected: passed += 1 else: - failed += 1 print(f"FAIL: {ckpt_key}") - print(f" expected: {expected}") - print(f" got: {result}") - - print(f"\n{passed} passed, {failed} failed") - return 0 if failed == 0 else 1 + print(f" Expected: {expected}") + print(f" Got: {result}") + failed += 1 + + print(f"\n{passed}/{passed+failed} tests passed") + if failed > 0: + sys.exit(1) if __name__ == "__main__": - sys.exit(main()) + test_mapper()