Add unit tests for NVFP4 weight mapper and inverse RoPE BF16
This commit is contained in:
126
tests/test_inv_rope.py
Normal file
126
tests/test_inv_rope.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""Test _apply_inv_rope_bf16: inverse RoPE should undo forward RoPE."""
|
||||
import torch
|
||||
import math
|
||||
|
||||
def apply_rope_bf16(x, positions, cos_sin_cache, nope_dim, rope_dim):
|
||||
"""Forward GPT-J style RoPE."""
|
||||
if rope_dim == 0 or x.numel() == 0:
|
||||
return x
|
||||
half_rot = rope_dim // 2
|
||||
x_f32 = x.to(torch.float32)
|
||||
cache = cos_sin_cache.index_select(0, positions.to(torch.long))
|
||||
cos = cache[:, :half_rot].to(torch.float32)
|
||||
sin = cache[:, half_rot:2*half_rot].to(torch.float32)
|
||||
view_shape = (positions.shape[0], 1, half_rot)
|
||||
cos = cos.view(view_shape)
|
||||
sin = sin.view(view_shape)
|
||||
rope = x_f32[..., nope_dim:]
|
||||
y_even = rope[..., 0::2]
|
||||
y_odd = rope[..., 1::2]
|
||||
rope_out = torch.stack(
|
||||
(y_even * cos - y_odd * sin, y_odd * cos + y_even * sin),
|
||||
dim=-1,
|
||||
).flatten(-2)
|
||||
x_f32 = x_f32.clone()
|
||||
x_f32[..., nope_dim:] = rope_out
|
||||
return x_f32.to(x.dtype)
|
||||
|
||||
def apply_inv_rope_bf16(o, positions, cos_sin_cache, nope_dim, rope_dim):
|
||||
"""Inverse GPT-J style RoPE (sin -> -sin)."""
|
||||
if rope_dim == 0 or o.numel() == 0:
|
||||
return o
|
||||
half_rot = rope_dim // 2
|
||||
o_f32 = o.to(torch.float32)
|
||||
cache = cos_sin_cache.index_select(0, positions.to(torch.long))
|
||||
cos = cache[:, :half_rot].to(torch.float32)
|
||||
sin = cache[:, half_rot:2*half_rot].to(torch.float32)
|
||||
view_shape = (positions.shape[0], 1, half_rot)
|
||||
cos = cos.view(view_shape)
|
||||
sin = sin.view(view_shape)
|
||||
rope = o_f32[..., nope_dim:]
|
||||
y_even = rope[..., 0::2]
|
||||
y_odd = rope[..., 1::2]
|
||||
rope_out = torch.stack(
|
||||
(y_even * cos + y_odd * sin, y_odd * cos - y_even * sin),
|
||||
dim=-1,
|
||||
).flatten(-2)
|
||||
o_f32 = o_f32.clone()
|
||||
o_f32[..., nope_dim:] = rope_out
|
||||
return o_f32.to(o.dtype)
|
||||
|
||||
|
||||
def test_inv_rope_roundtrip():
|
||||
"""Forward RoPE then inverse RoPE should be identity."""
|
||||
torch.manual_seed(42)
|
||||
num_tokens = 8
|
||||
num_heads = 16
|
||||
head_dim = 512
|
||||
nope_dim = 448
|
||||
rope_dim = 64
|
||||
max_pos = 1024
|
||||
|
||||
# Build cos/sin cache (like RotaryEmbedding)
|
||||
half_rot = rope_dim // 2
|
||||
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, half_rot, dtype=torch.float32) / half_rot))
|
||||
positions = torch.randint(0, max_pos, (num_tokens,))
|
||||
freqs = positions.float().unsqueeze(1) * inv_freq.unsqueeze(0)
|
||||
cos_cache_full = torch.zeros(max_pos, half_rot, dtype=torch.float32)
|
||||
sin_cache_full = torch.zeros(max_pos, half_rot, dtype=torch.float32)
|
||||
cos_vals = torch.cos(freqs)
|
||||
sin_vals = torch.sin(freqs)
|
||||
for i, p in enumerate(positions):
|
||||
cos_cache_full[p] = cos_vals[i]
|
||||
sin_cache_full[p] = sin_vals[i]
|
||||
cos_sin_cache = torch.cat([cos_cache_full, sin_cache_full], dim=1)
|
||||
|
||||
x = torch.randn(num_tokens, num_heads, head_dim, dtype=torch.bfloat16)
|
||||
|
||||
# Forward RoPE
|
||||
x_rope = apply_rope_bf16(x, positions, cos_sin_cache, nope_dim, rope_dim)
|
||||
|
||||
# Inverse RoPE
|
||||
x_recovered = apply_inv_rope_bf16(x_rope, positions, cos_sin_cache, nope_dim, rope_dim)
|
||||
|
||||
# Should be identity (within BF16 precision)
|
||||
diff = (x.to(torch.float32) - x_recovered.to(torch.float32)).abs().max().item()
|
||||
print(f"Max abs diff: {diff:.6f}")
|
||||
assert diff < 0.05, f"Roundtrip error too large: {diff}"
|
||||
print("PASS: inverse RoPE roundtrip within tolerance")
|
||||
|
||||
|
||||
def test_nope_dim_unchanged():
|
||||
"""NoPE dimensions should be unchanged by inverse RoPE."""
|
||||
torch.manual_seed(42)
|
||||
num_tokens = 4
|
||||
num_heads = 4
|
||||
head_dim = 128
|
||||
nope_dim = 96
|
||||
rope_dim = 32
|
||||
max_pos = 512
|
||||
|
||||
half_rot = rope_dim // 2
|
||||
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, half_rot, dtype=torch.float32) / half_rot))
|
||||
positions = torch.randint(0, max_pos, (num_tokens,))
|
||||
freqs = positions.float().unsqueeze(1) * inv_freq.unsqueeze(0)
|
||||
cos_cache_full = torch.zeros(max_pos, half_rot, dtype=torch.float32)
|
||||
sin_cache_full = torch.zeros(max_pos, half_rot, dtype=torch.float32)
|
||||
cos_vals = torch.cos(freqs)
|
||||
sin_vals = torch.sin(freqs)
|
||||
for i, p in enumerate(positions):
|
||||
cos_cache_full[p] = cos_vals[i]
|
||||
sin_cache_full[p] = sin_vals[i]
|
||||
cos_sin_cache = torch.cat([cos_cache_full, sin_cache_full], dim=1)
|
||||
|
||||
x = torch.randn(num_tokens, num_heads, head_dim, dtype=torch.bfloat16)
|
||||
x_inv = apply_inv_rope_bf16(x, positions, cos_sin_cache, nope_dim, rope_dim)
|
||||
|
||||
# NoPE dims should be unchanged
|
||||
nope_diff = (x[..., :nope_dim].to(torch.float32) - x_inv[..., :nope_dim].to(torch.float32)).abs().max().item()
|
||||
print(f"NoPE max diff: {nope_diff:.6f}")
|
||||
assert nope_diff == 0.0, "NoPE dimensions should be unchanged"
|
||||
print("PASS: NoPE dimensions unchanged")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_inv_rope_roundtrip()
|
||||
test_nope_dim_unchanged()
|
||||
@@ -1,204 +1,187 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Unit test for the NVFP4 weights mapper.
|
||||
|
||||
Validates that checkpoint key names from our ModelOpt-quantized
|
||||
DeepSeek-V4-Pro checkpoint are correctly mapped to vLLM model
|
||||
parameter names.
|
||||
|
||||
This can run WITHOUT vLLM or CUDA — it only tests the mapper logic.
|
||||
"""
|
||||
|
||||
"""Test that the NVFP4 weight mapper correctly maps checkpoint keys to model parameter names."""
|
||||
import re
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
# Inline WeightsMapper (from vllm.model_executor.models.utils)
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Iterable, Mapping
|
||||
|
||||
@dataclass
|
||||
class WeightsMapper:
|
||||
"""Simplified WeightsMapper for testing."""
|
||||
orig_to_new_regex: Mapping[re.Pattern, str | None] = field(default_factory=dict)
|
||||
orig_to_new_substr: Mapping[str, str | None] = field(default_factory=dict)
|
||||
orig_to_new_prefix: Mapping[str, str | None] = field(default_factory=dict)
|
||||
orig_to_new_suffix: Mapping[str, str | None] = field(default_factory=dict)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
orig_to_new_prefix: Optional[dict] = None,
|
||||
orig_to_new_regex: Optional[dict] = None,
|
||||
orig_to_new_suffix: Optional[dict] = None,
|
||||
orig_to_new_substr: Optional[dict] = None,
|
||||
):
|
||||
self.prefix_map = orig_to_new_prefix or {}
|
||||
self.regex_map = orig_to_new_regex or {}
|
||||
self.suffix_map = orig_to_new_suffix or {}
|
||||
self.substr_map = orig_to_new_substr or {}
|
||||
|
||||
def map_name(self, name: str) -> str:
|
||||
# 1. Prefix
|
||||
for old, new in self.prefix_map.items():
|
||||
if name.startswith(old):
|
||||
name = new + name[len(old):]
|
||||
break
|
||||
|
||||
# 2. Regex
|
||||
for pattern, replacement in self.regex_map.items():
|
||||
name = pattern.sub(replacement, name)
|
||||
|
||||
# 3. Suffix
|
||||
for old, new in self.suffix_map.items():
|
||||
if name.endswith(old):
|
||||
name = name[: -len(old)] + new
|
||||
break
|
||||
|
||||
# 4. Substr (ordered dict — specific before general)
|
||||
for old, new in self.substr_map.items():
|
||||
if old in name:
|
||||
name = name.replace(old, new, 1)
|
||||
|
||||
return name
|
||||
def _map_name(self, key: str) -> str | None:
|
||||
for pattern, new_key in self.orig_to_new_regex.items():
|
||||
if pattern.search(key):
|
||||
if new_key is None:
|
||||
return None
|
||||
key = pattern.sub(new_key, key)
|
||||
for substr, new_key in self.orig_to_new_substr.items():
|
||||
if substr in key:
|
||||
if new_key is None:
|
||||
return None
|
||||
key = key.replace(substr, new_key, 1)
|
||||
for prefix, new_key in self.orig_to_new_prefix.items():
|
||||
if key.startswith(prefix):
|
||||
if new_key is None:
|
||||
return None
|
||||
key = key.replace(prefix, new_key, 1)
|
||||
for suffix, new_key in self.orig_to_new_suffix.items():
|
||||
if key.endswith(suffix):
|
||||
if new_key is None:
|
||||
return None
|
||||
key = new_key.join(key.rsplit(suffix, 1))
|
||||
return key
|
||||
|
||||
|
||||
def _make_deepseek_v4_nvfp4_weights_mapper() -> WeightsMapper:
|
||||
"""Exact copy of the mapper from deepseek_v4.py."""
|
||||
def make_nvfp4_mapper() -> WeightsMapper:
|
||||
expert_rename_regex = {
|
||||
re.compile(r"(\.experts\.\d+\.)gate_proj\."): r"\1w1.",
|
||||
re.compile(r"(\.experts\.\d+\.)up_proj\."): r"\1w3.",
|
||||
re.compile(r"(\.experts\.\d+\.)down_proj\."): r"\1w2.",
|
||||
}
|
||||
|
||||
suffix_renames = {}
|
||||
|
||||
substr_renames = {
|
||||
# === Compressor (non-indexer) NVFP4 renames ===
|
||||
"compressor.kv_proj.": "compressor.wkv.",
|
||||
"compressor.gate_proj.": "compressor.wgate.",
|
||||
"compressor.kv_norm.": "compressor.norm.",
|
||||
"compressor.position_bias": "compressor.ape",
|
||||
# === Attention compressor (before indexer renames) ===
|
||||
".self_attn.compressor.": ".attn.mla_attn.compressor.",
|
||||
# === Indexer params ===
|
||||
"compressor.indexer.q_b_proj.": "indexer.wq_b.",
|
||||
"compressor.indexer.weights_proj.": "indexer.weights_proj.",
|
||||
"compressor.indexer.kv_norm.": "indexer.k_norm.",
|
||||
"compressor.indexer.kv_proj.": "indexer.compressor.wkv.",
|
||||
"compressor.indexer.gate_proj.": "indexer.compressor.wgate.",
|
||||
"compressor.indexer.position_bias": "indexer.compressor.ape",
|
||||
# === Attention projections ===
|
||||
".self_attn.q_a_proj.": ".attn.wq_a.",
|
||||
".self_attn.kv_proj.": ".attn.wkv.",
|
||||
".self_attn.q_b_proj.": ".attn.wq_b.",
|
||||
".self_attn.o_a_proj.": ".attn.wo_a.",
|
||||
".self_attn.o_b_proj.": ".attn.wo_b.",
|
||||
".self_attn.q_a_norm.": ".attn.q_a_norm.",
|
||||
".self_attn.kv_norm.": ".attn.kv_norm.",
|
||||
".self_attn.sinks": ".attn.sinks",
|
||||
# Shared expert projections
|
||||
".mlp.shared_experts.gate_proj.": ".ffn.shared_experts.w1.",
|
||||
".mlp.shared_experts.up_proj.": ".ffn.shared_experts.w3.",
|
||||
".mlp.shared_experts.down_proj.": ".ffn.shared_experts.down_proj.",
|
||||
# General renames
|
||||
".mlp.": ".ffn.",
|
||||
".self_attn.": ".attn.",
|
||||
}
|
||||
|
||||
return WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
"layers.": "model.layers.",
|
||||
"embed_tokens.": "model.embed_tokens.",
|
||||
"embed.": "model.embed.",
|
||||
"norm.": "model.norm.",
|
||||
"hc_head": "model.hc_head",
|
||||
"mtp.": "model.mtp.",
|
||||
},
|
||||
orig_to_new_regex=expert_rename_regex,
|
||||
orig_to_new_suffix=suffix_renames,
|
||||
orig_to_new_substr=substr_renames,
|
||||
orig_to_new_suffix={
|
||||
"head.weight": "lm_head.weight",
|
||||
"embed.weight": "embed_tokens.weight",
|
||||
".ffn_norm.weight": ".ffn.norm_gate.norm.weight",
|
||||
".ffn.gate.weight": ".ffn.norm_gate.gate.weight",
|
||||
".ffn.gate.bias": ".ffn.norm_gate.e_score_correction_bias",
|
||||
".ffn.gate.tid2eid": ".ffn.norm_gate.tid2eid",
|
||||
},
|
||||
orig_to_new_substr={
|
||||
".self_attn.compressor.indexer.q_b_proj.": ".attn.indexer.wq_b.",
|
||||
".self_attn.compressor.indexer.weights_proj.": ".attn.indexer.weights_proj.",
|
||||
".self_attn.compressor.indexer.kv_norm.": ".attn.indexer.k_norm.",
|
||||
".self_attn.compressor.indexer.kv_proj.": ".attn.indexer.compressor.wkv.",
|
||||
".self_attn.compressor.indexer.gate_proj.": ".attn.indexer.compressor.wgate.",
|
||||
".self_attn.compressor.indexer.position_bias": ".attn.indexer.compressor.ape",
|
||||
"compressor.kv_proj.": "compressor.wkv.",
|
||||
"compressor.gate_proj.": "compressor.wgate.",
|
||||
"compressor.kv_norm.": "compressor.norm.",
|
||||
"compressor.position_bias": "compressor.ape",
|
||||
".self_attn.compressor.": ".attn.compressor.",
|
||||
".self_attn.q_a_proj.": ".attn.wq_a.",
|
||||
".self_attn.kv_proj.": ".attn.wkv.",
|
||||
".self_attn.q_b_proj.": ".attn.wq_b.",
|
||||
".self_attn.o_a_proj.": ".attn.wo_a.",
|
||||
".self_attn.o_b_proj.": ".attn.wo_b.",
|
||||
".self_attn.q_a_norm.": ".attn.q_norm.",
|
||||
".self_attn.kv_norm.": ".attn.kv_norm.",
|
||||
".self_attn.sinks": ".attn.attn_sink",
|
||||
".mlp.shared_experts.gate_proj.": ".ffn.shared_experts.w1.",
|
||||
".mlp.shared_experts.up_proj.": ".ffn.shared_experts.w3.",
|
||||
".mlp.shared_experts.down_proj.": ".ffn.shared_experts.down_proj.",
|
||||
".mlp.": ".ffn.",
|
||||
".self_attn.": ".attn.",
|
||||
"input_layernorm.": "attn_norm.",
|
||||
"post_attention_layernorm.": "ffn_norm.",
|
||||
".attn_hc.fn": ".hc_attn_fn",
|
||||
".attn_hc.base": ".hc_attn_base",
|
||||
".attn_hc.scale": ".hc_attn_scale",
|
||||
".ffn_hc.fn": ".hc_ffn_fn",
|
||||
".ffn_hc.base": ".hc_ffn_base",
|
||||
".ffn_hc.scale": ".hc_ffn_scale",
|
||||
"hc_head.fn": "hc_head_fn",
|
||||
"hc_head.base": "hc_head_base",
|
||||
"hc_head.scale": "hc_head_scale",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
TEST_CASES = [
|
||||
# Embedding & top-level
|
||||
("embed_tokens.weight", "model.embed_tokens.weight"),
|
||||
("norm.weight", "model.norm.weight"),
|
||||
("hc_head.hc_fn", "model.hc_head.hc_fn"),
|
||||
("hc_head.hc_base", "model.hc_head.hc_base"),
|
||||
("hc_head.hc_scale", "model.hc_head.hc_scale"),
|
||||
("lm_head.weight", "lm_head.weight"),
|
||||
|
||||
# Attention — self_attn → attn
|
||||
("layers.0.self_attn.q_a_proj.weight", "model.layers.0.attn.wq_a.weight"),
|
||||
("layers.0.self_attn.q_a_proj.input_scale", "model.layers.0.attn.wq_a.input_scale"),
|
||||
("layers.0.self_attn.kv_proj.weight", "model.layers.0.attn.wkv.weight"),
|
||||
("layers.0.self_attn.q_b_proj.weight", "model.layers.0.attn.wq_b.weight"),
|
||||
("layers.0.self_attn.o_a_proj.weight", "model.layers.0.attn.wo_a.weight"),
|
||||
("layers.0.self_attn.o_b_proj.weight", "model.layers.0.attn.wo_b.weight"),
|
||||
("layers.0.self_attn.o_b_proj.input_scale", "model.layers.0.attn.wo_b.input_scale"),
|
||||
("layers.0.self_attn.q_a_norm.weight", "model.layers.0.attn.q_a_norm.weight"),
|
||||
("layers.0.self_attn.kv_norm.weight", "model.layers.0.attn.kv_norm.weight"),
|
||||
("layers.0.self_attn.sinks", "model.layers.0.attn.sinks"),
|
||||
|
||||
# Compressor (non-indexer): kv_proj → wkv, gate_proj → wgate
|
||||
("layers.0.self_attn.compressor.kv_proj.weight", "model.layers.0.attn.mla_attn.compressor.wkv.weight"),
|
||||
("layers.0.self_attn.compressor.kv_proj.input_scale", "model.layers.0.attn.mla_attn.compressor.wkv.input_scale"),
|
||||
("layers.0.self_attn.compressor.kv_proj.weight_scale", "model.layers.0.attn.mla_attn.compressor.wkv.weight_scale"),
|
||||
("layers.0.self_attn.compressor.kv_proj.weight_scale_2", "model.layers.0.attn.mla_attn.compressor.wkv.weight_scale_2"),
|
||||
("layers.0.self_attn.compressor.gate_proj.weight", "model.layers.0.attn.mla_attn.compressor.wgate.weight"),
|
||||
("layers.0.self_attn.compressor.gate_proj.input_scale", "model.layers.0.attn.mla_attn.compressor.wgate.input_scale"),
|
||||
("layers.0.self_attn.compressor.gate_proj.weight_scale", "model.layers.0.attn.mla_attn.compressor.wgate.weight_scale"),
|
||||
("layers.0.self_attn.compressor.gate_proj.weight_scale_2", "model.layers.0.attn.mla_attn.compressor.wgate.weight_scale_2"),
|
||||
("layers.0.self_attn.compressor.kv_norm.weight", "model.layers.0.attn.mla_attn.compressor.norm.weight"),
|
||||
("layers.0.self_attn.compressor.position_bias", "model.layers.0.attn.mla_attn.compressor.ape"),
|
||||
|
||||
# Indexer own params
|
||||
("layers.10.self_attn.compressor.indexer.q_b_proj.weight", "model.layers.10.attn.mla_attn.indexer.wq_b.weight"),
|
||||
("layers.10.self_attn.compressor.indexer.q_b_proj.input_scale", "model.layers.10.attn.mla_attn.indexer.wq_b.input_scale"),
|
||||
("layers.10.self_attn.compressor.indexer.weights_proj.weight", "model.layers.10.attn.mla_attn.indexer.weights_proj.weight"),
|
||||
("layers.10.self_attn.compressor.indexer.weights_proj.input_scale", "model.layers.10.attn.mla_attn.indexer.weights_proj.input_scale"),
|
||||
("layers.10.self_attn.compressor.indexer.kv_norm.weight", "model.layers.10.attn.mla_attn.indexer.k_norm.weight"),
|
||||
|
||||
# Indexer's compressor
|
||||
("layers.10.self_attn.compressor.indexer.kv_proj.weight", "model.layers.10.attn.mla_attn.indexer.compressor.wkv.weight"),
|
||||
("layers.10.self_attn.compressor.indexer.kv_proj.input_scale", "model.layers.10.attn.mla_attn.indexer.compressor.wkv.input_scale"),
|
||||
("layers.10.self_attn.compressor.indexer.gate_proj.weight", "model.layers.10.attn.mla_attn.indexer.compressor.wgate.weight"),
|
||||
("layers.10.self_attn.compressor.indexer.gate_proj.input_scale", "model.layers.10.attn.mla_attn.indexer.compressor.wgate.input_scale"),
|
||||
("layers.10.self_attn.compressor.indexer.position_bias", "model.layers.10.attn.mla_attn.indexer.compressor.ape"),
|
||||
|
||||
# MoE gate
|
||||
("layers.0.mlp.gate.tid2eid", "model.layers.0.ffn.gate.tid2eid"),
|
||||
("layers.0.mlp.gate.weight", "model.layers.0.ffn.gate.weight"),
|
||||
|
||||
# Expert weights — gate_proj → w1, up_proj → w3, down_proj → w2
|
||||
("layers.0.mlp.experts.0.gate_proj.weight", "model.layers.0.ffn.experts.0.w1.weight"),
|
||||
("layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.ffn.experts.0.w3.weight"),
|
||||
("layers.0.mlp.experts.0.down_proj.weight", "model.layers.0.ffn.experts.0.w2.weight"),
|
||||
("layers.0.mlp.experts.0.gate_proj.input_scale", "model.layers.0.ffn.experts.0.w1.input_scale"),
|
||||
("layers.0.mlp.experts.0.gate_proj.weight_scale", "model.layers.0.ffn.experts.0.w1.weight_scale"),
|
||||
("layers.0.mlp.experts.0.gate_proj.weight_scale_2", "model.layers.0.ffn.experts.0.w1.weight_scale_2"),
|
||||
|
||||
# Shared experts
|
||||
("layers.0.mlp.shared_experts.gate_proj.weight", "model.layers.0.ffn.shared_experts.w1.weight"),
|
||||
("layers.0.mlp.shared_experts.up_proj.weight", "model.layers.0.ffn.shared_experts.w3.weight"),
|
||||
("layers.0.mlp.shared_experts.down_proj.weight", "model.layers.0.ffn.shared_experts.down_proj.weight"),
|
||||
("layers.0.mlp.shared_experts.gate_proj.input_scale", "model.layers.0.ffn.shared_experts.w1.input_scale"),
|
||||
("layers.0.mlp.shared_experts.down_proj.weight_scale", "model.layers.0.ffn.shared_experts.down_proj.weight_scale"),
|
||||
|
||||
# Layer norm
|
||||
("layers.0.post_attention_layernorm.weight", "model.layers.0.post_attention_layernorm.weight"),
|
||||
]
|
||||
|
||||
|
||||
def main():
|
||||
mapper = _make_deepseek_v4_nvfp4_weights_mapper()
|
||||
def test_mapper():
|
||||
mapper = make_nvfp4_mapper()
|
||||
|
||||
test_cases = [
|
||||
# Attention projections
|
||||
("layers.0.self_attn.q_a_proj.weight", "model.layers.0.attn.wq_a.weight"),
|
||||
("layers.0.self_attn.q_a_proj.weight_scale", "model.layers.0.attn.wq_a.weight_scale"),
|
||||
("layers.0.self_attn.q_a_proj.weight_scale_2", "model.layers.0.attn.wq_a.weight_scale_2"),
|
||||
("layers.0.self_attn.q_a_proj.input_scale", "model.layers.0.attn.wq_a.input_scale"),
|
||||
("layers.0.self_attn.kv_proj.weight", "model.layers.0.attn.wkv.weight"),
|
||||
("layers.0.self_attn.q_b_proj.weight", "model.layers.0.attn.wq_b.weight"),
|
||||
("layers.0.self_attn.o_a_proj.weight", "model.layers.0.attn.wo_a.weight"),
|
||||
("layers.0.self_attn.o_b_proj.weight", "model.layers.0.attn.wo_b.weight"),
|
||||
("layers.0.self_attn.o_b_proj.weight_scale", "model.layers.0.attn.wo_b.weight_scale"),
|
||||
("layers.0.self_attn.q_a_norm.weight", "model.layers.0.attn.q_norm.weight"),
|
||||
("layers.0.self_attn.kv_norm.weight", "model.layers.0.attn.kv_norm.weight"),
|
||||
("layers.0.self_attn.sinks", "model.layers.0.attn.attn_sink"),
|
||||
|
||||
# Compressor (non-indexer)
|
||||
("layers.0.self_attn.compressor.kv_proj.weight", "model.layers.0.attn.compressor.wkv.weight"),
|
||||
("layers.0.self_attn.compressor.kv_proj.input_scale", "model.layers.0.attn.compressor.wkv.input_scale"),
|
||||
("layers.0.self_attn.compressor.gate_proj.weight", "model.layers.0.attn.compressor.wgate.weight"),
|
||||
("layers.0.self_attn.compressor.kv_norm.weight", "model.layers.0.attn.compressor.norm.weight"),
|
||||
("layers.0.self_attn.compressor.position_bias", "model.layers.0.attn.compressor.ape"),
|
||||
|
||||
# Indexer
|
||||
("layers.2.self_attn.compressor.indexer.q_b_proj.weight", "model.layers.2.attn.indexer.wq_b.weight"),
|
||||
("layers.2.self_attn.compressor.indexer.weights_proj.weight", "model.layers.2.attn.indexer.weights_proj.weight"),
|
||||
("layers.2.self_attn.compressor.indexer.kv_norm.weight", "model.layers.2.attn.indexer.k_norm.weight"),
|
||||
("layers.2.self_attn.compressor.indexer.kv_proj.weight", "model.layers.2.attn.indexer.compressor.wkv.weight"),
|
||||
("layers.2.self_attn.compressor.indexer.gate_proj.weight", "model.layers.2.attn.indexer.compressor.wgate.weight"),
|
||||
("layers.2.self_attn.compressor.indexer.position_bias", "model.layers.2.attn.indexer.compressor.ape"),
|
||||
|
||||
# Expert weights
|
||||
("layers.0.mlp.experts.0.gate_proj.weight", "model.layers.0.ffn.experts.0.w1.weight"),
|
||||
("layers.0.mlp.experts.0.gate_proj.weight_scale", "model.layers.0.ffn.experts.0.w1.weight_scale"),
|
||||
("layers.0.mlp.experts.0.gate_proj.weight_scale_2", "model.layers.0.ffn.experts.0.w1.weight_scale_2"),
|
||||
("layers.0.mlp.experts.0.gate_proj.input_scale", "model.layers.0.ffn.experts.0.w1.input_scale"),
|
||||
("layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.ffn.experts.0.w3.weight"),
|
||||
("layers.0.mlp.experts.0.down_proj.weight", "model.layers.0.ffn.experts.0.w2.weight"),
|
||||
|
||||
# Shared experts
|
||||
("layers.0.mlp.shared_experts.gate_proj.weight", "model.layers.0.ffn.shared_experts.w1.weight"),
|
||||
("layers.0.mlp.shared_experts.up_proj.weight", "model.layers.0.ffn.shared_experts.w3.weight"),
|
||||
("layers.0.mlp.shared_experts.down_proj.weight", "model.layers.0.ffn.shared_experts.down_proj.weight"),
|
||||
|
||||
# MoE gate + norm
|
||||
("layers.0.mlp.gate.weight", "model.layers.0.ffn.norm_gate.gate.weight"),
|
||||
("layers.0.mlp.gate.tid2eid", "model.layers.0.ffn.norm_gate.tid2eid"),
|
||||
("layers.0.input_layernorm.weight", "model.layers.0.attn_norm.weight"),
|
||||
("layers.0.post_attention_layernorm.weight", "model.layers.0.ffn.norm_gate.norm.weight"),
|
||||
|
||||
# HC params
|
||||
("layers.0.attn_hc.fn", "model.layers.0.hc_attn_fn"),
|
||||
("layers.0.ffn_hc.scale", "model.layers.0.hc_ffn_scale"),
|
||||
|
||||
# Global params
|
||||
("embed.weight", "model.embed_tokens.weight"),
|
||||
("norm.weight", "model.norm.weight"),
|
||||
("hc_head.fn", "model.hc_head_fn"),
|
||||
|
||||
# MTP (already uses ffn prefix in checkpoint)
|
||||
("mtp.0.ffn.experts.0.w1.weight", "model.mtp.0.ffn.experts.0.w1.weight"),
|
||||
("mtp.0.ffn_norm.weight", "model.mtp.0.ffn.norm_gate.norm.weight"),
|
||||
]
|
||||
|
||||
passed = 0
|
||||
failed = 0
|
||||
|
||||
for ckpt_key, expected in TEST_CASES:
|
||||
result = mapper.map_name(ckpt_key)
|
||||
for ckpt_key, expected in test_cases:
|
||||
result = mapper._map_name(ckpt_key)
|
||||
if result == expected:
|
||||
passed += 1
|
||||
else:
|
||||
failed += 1
|
||||
print(f"FAIL: {ckpt_key}")
|
||||
print(f" expected: {expected}")
|
||||
print(f" got: {result}")
|
||||
|
||||
print(f"\n{passed} passed, {failed} failed")
|
||||
return 0 if failed == 0 else 1
|
||||
print(f" Expected: {expected}")
|
||||
print(f" Got: {result}")
|
||||
failed += 1
|
||||
|
||||
print(f"\n{passed}/{passed+failed} tests passed")
|
||||
if failed > 0:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
test_mapper()
|
||||
|
||||
Reference in New Issue
Block a user