- Move test_*.py → tests/integration/ - Move probe_*.py, dump_*.py → helpers/ - Move PERFORMANCE_AUDIT.md → docs/ - Move single_shot_PYTORCH_REFERENCE.py → dsv4/reference/ - Fix 3 import references in test_layer_comparison, test_mhc_comparison, test_compressor_position_bias - Add helpers/import_closure.py (dead-code detection tool)
170 lines
7.0 KiB
Python
170 lines
7.0 KiB
Python
#!/usr/bin/env python3
|
|
"""Focused comparison: production mHC vs PyTorch reference mHC at a specific layer.

This test:
1. Loads the reference weights and builds the reference mHC blocks
2. Processes one input token through the first 3 layers with the reference pipeline
3. Runs the attention mHC pre-block of layer 3 in both the reference and the production implementation
4. Compares x_in and the A / B / C tensors of both (magnitudes and cosine similarity) to show where they diverge
"""
|
|
import os, sys, json, time, math, torch, torch.nn.functional as F
|
|
from pathlib import Path
|
|
|
|
# Directory holding config.json and the safetensors shards; override it with
# the CHECKPOINT_DIR environment variable.
CHECKPOINT_DIR = os.environ.get("CHECKPOINT_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4")

# Device on which every tensor and module in this script is placed.
DEVICE = "cuda:0"
|
|
# Default stabiliser used by sinkhorn_knopp: added to the softmax output and to
# every normalisation denominator.
HC_EPS = 1e-6


def sinkhorn_knopp(logits, t_max=20, eps=HC_EPS):
    """Normalise ``logits`` towards a doubly-stochastic matrix (Sinkhorn-Knopp).

    The last two dimensions of ``logits`` are treated as the matrix; any
    leading dimensions are carried along unchanged.

    Args:
        logits: Tensor of unnormalised scores.
        t_max: Total number of normalisation rounds. The first round is a
            softmax over the last dim followed by a normalisation over dim -2;
            each of the remaining ``t_max - 1`` rounds normalises over dim -1
            and then over dim -2.
        eps: Small constant added to the softmax output and to every
            denominator so that no division by zero can occur.

    Returns:
        Tensor with the same shape as ``logits`` whose sums along dim -2 and
        dim -1 are both approximately one.
    """
    # Round 1: softmax supplies the dim -1 normalisation, then normalise dim -2.
    M = torch.softmax(logits, -1) + eps
    M = M / (M.sum(-2, keepdim=True) + eps)

    # Remaining rounds alternate the two normalisations, always ending on dim -2.
    for _ in range(t_max - 1):
        M = M / (M.sum(-1, keepdim=True) + eps)
        M = M / (M.sum(-2, keepdim=True) + eps)
    return M
|
|
|
|
def unweighted_rmsnorm(x, eps=1e-6):
    """RMS-normalise ``x`` over its last dimension without a learned weight.

    The computation is done in float32 and the result is cast back to the
    dtype of ``x``.

    Args:
        x: Input tensor; normalisation is applied along dim -1.
        eps: Constant added to the mean square before the reciprocal square
            root.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    x_f = x.float()
    # Reciprocal of the root-mean-square, kept broadcastable over the last dim.
    inv_rms = x_f.pow(2).mean(-1, keepdim=True).add(eps).rsqrt()
    return (x_f * inv_rms).to(x.dtype)
|
|
|
|
def rmsnorm(x, w, eps=1e-6):
    """RMS-normalise ``x`` over its last dimension and scale it by ``w``.

    Both the normalisation and the multiplication by ``w`` are done in
    float32; only the final result is cast back to the dtype of ``x``.

    Args:
        x: Input tensor; normalisation is applied along dim -1.
        w: Weight tensor multiplied into the normalised values (must broadcast
            against ``x``).
        eps: Constant added to the mean square before the reciprocal square
            root.

    Returns:
        Tensor with the dtype of ``x``.
    """
    x_f = x.float()
    # Reciprocal of the root-mean-square, kept broadcastable over the last dim.
    inv_rms = x_f.pow(2).mean(-1, keepdim=True).add(eps).rsqrt()
    return (x_f * inv_rms * w.float()).to(x.dtype)
|
|
|
|
# Magnitudes of the eight 4-bit float codes, indexed by the low three bits of a
# nibble; bit 3 of the nibble carries the sign (see dequant_nvfp4).
FP4_LUT = torch.tensor([0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def dequant_nvfp4(weight, weight_scale, weight_scale_2=None, input_scale=None):
    """Dequantise a packed 4-bit weight matrix to bfloat16.

    Each byte of ``weight`` holds two 4-bit codes: the low nibble is the
    even-indexed element and the high nibble the following odd-indexed one.
    Within a nibble, bit 3 is the sign and the low three bits index
    ``FP4_LUT``.

    Args:
        weight: 2-D integer tensor of shape ``(out_features, in_features / 2)``.
        weight_scale: Per-block scales; each column is repeated 16 times along
            dim 1, so it must have ``in_features / 16`` columns.
        weight_scale_2: Optional extra scale multiplied into ``weight_scale``.
        input_scale: Accepted for signature compatibility; not used here.

    Returns:
        bfloat16 tensor of shape ``(out_features, in_features)``.
    """
    out_features, packed_in = weight.shape
    in_features = packed_in * 2
    lut = FP4_LUT.to(device=weight.device, dtype=torch.float32)

    def _decode(nibble):
        # Low three bits select the magnitude, bit 3 flips the sign.
        magnitude = lut[(nibble & 0x07).long()]
        sign = torch.where((nibble >> 3).bool(), -1., 1.)
        return magnitude * sign

    lo_f = _decode((weight & 0x0F).to(torch.int8))
    hi_f = _decode((weight >> 4).to(torch.int8))

    # Interleave so that element 2k comes from the low nibble of byte k and
    # element 2k + 1 from its high nibble.
    w = torch.stack([lo_f, hi_f], -1).reshape(out_features, in_features)

    # One scale per block of 16 consecutive input elements.
    s = weight_scale.float().repeat_interleave(16, 1)
    if weight_scale_2 is not None:
        s = s * weight_scale_2.float()
    return (w * s).bfloat16()
|
|
|
|
def _flat_cosine(a, b):
    """Return the cosine similarity of two tensors flattened to float32 vectors."""
    return F.cosine_similarity(a.flatten().float(), b.flatten().float(), dim=0).item()


@torch.no_grad()  # inference only: do not record an autograd graph through the layers
def main():
    """Compare the reference and production attention mHC pre-block on one token.

    Steps:
      1. Read ``config.json`` from ``CHECKPOINT_DIR`` and report the model size.
      2. Load the reference weights and build reference mHC blocks for the
         first few layers.
      3. Embed a single token and run it through the first three layers with
         the reference ``forward_layer`` to obtain a realistic state ``X``.
      4. Feed that state to the attention mHC ``pre_block`` of the next layer
         in both the reference (``mHCBlock``) and production (``mHCLayer``)
         implementations.
      5. Print magnitudes and cosine similarities of ``x_in``, ``A``, ``C``
         and ``B``.

    Raises:
        ValueError: If the model has too few layers for the comparison.
    """
    torch.manual_seed(42)

    warmup_layers = 3               # reference layers run before the comparison
    compare_layer = warmup_layers   # layer whose attention mHC is compared
    max_built_layers = 5            # mHC blocks are prepared for at most this many layers

    with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
        cfg = json.load(f)
    H = cfg["hidden_size"]
    n_hc = cfg.get("n_hc", 4)
    n_layers = cfg["num_hidden_layers"]
    n_experts = cfg["n_routed_experts"]
    top_k = cfg.get("num_experts_per_tok", 6)
    print(f"Model: {n_layers} layers, {H} hidden, {n_experts} experts, top-{top_k}")

    # Fail early with a clear message instead of an IndexError after the
    # (slow) weight load.
    if n_layers <= compare_layer:
        raise ValueError(
            f"need at least {compare_layer + 1} layers for this comparison, "
            f"config has {n_layers}"
        )

    # Load weights
    print("Loading weights...")
    from safetensors import safe_open

    cdir = Path(CHECKPOINT_DIR)
    index_path = cdir / "model.safetensors.index.json"
    weight_map = {}
    if index_path.exists():
        with open(index_path) as f:
            weight_map = json.load(f).get("weight_map", {})

    # Only the tensor count is reported, so read the tensor names from each
    # shard instead of materialising the whole checkpoint a second time; the
    # tensors themselves come from the reference loader below.
    tensor_names = set()
    for shard_name in sorted(set(weight_map.values())):
        shard_path = cdir / shard_name
        if shard_path.exists():
            with safe_open(str(shard_path), framework="pt") as shard:
                tensor_names.update(shard.keys())
    print(f"Loaded {len(tensor_names)} tensors")

    # Create a realistic hidden state (simulate running through a few layers)
    # Use token embedding + a few layers of mHC
    from dsv4.reference.single_shot_PYTORCH_REFERENCE import mHCBlock, load_weights as ref_load_weights, forward_layer
    ref_all_w = ref_load_weights(CHECKPOINT_DIR)

    def _load_ref_mhc(prefix):
        # Build one reference mHC block and load its fn / base / scale tensors.
        block = mHCBlock(H, n_hc, device=DEVICE)
        block.load(ref_all_w[f"{prefix}.fn"],
                   ref_all_w[f"{prefix}.base"],
                   ref_all_w[f"{prefix}.scale"])
        return block

    # Build reference mHC blocks and norm weights for the leading layers
    # (warm-up layers plus the compared one).
    attn_mhcs, ffn_mhcs = [], []
    attn_norms, ffn_norms = [], []
    for li in range(min(max_built_layers, n_layers)):
        attn_mhcs.append(_load_ref_mhc(f"model.layers.{li}.attn_hc"))
        ffn_mhcs.append(_load_ref_mhc(f"model.layers.{li}.ffn_hc"))
        attn_norms.append(ref_all_w[f"model.layers.{li}.input_layernorm.weight"].bfloat16().to(DEVICE))
        ffn_norms.append(ref_all_w[f"model.layers.{li}.post_attention_layernorm.weight"].bfloat16().to(DEVICE))

    # Embed one token. A functional lookup on the checkpoint weight avoids
    # allocating (and randomly initialising) a full nn.Embedding that would
    # be overwritten immediately.
    emb_w = ref_all_w["model.embed_tokens.weight"].bfloat16().to(DEVICE)

    # "The" token
    tid = 455
    token_ids = torch.tensor([tid], device=DEVICE)
    X = mHCBlock.init_state(F.embedding(token_ids, emb_w), n_hc=n_hc)
    print(f"\nInitial |X| = {X.abs().max().item():.2f}")

    # Run through the warm-up layers using the reference implementation.
    # NOTE(review): the constant tensor([3]) argument is passed unchanged for
    # every layer, as in the original; confirm its meaning against forward_layer.
    kv_cache = {}
    for li in range(warmup_layers):
        X = forward_layer(X, ref_all_w, li, cfg, None, None,
                          attn_mhcs[li], ffn_mhcs[li],
                          attn_norms[li], ffn_norms[li],
                          kv_cache, torch.tensor([3], device=DEVICE),
                          tid)
        print(f" Ref L{li}: |X| = {X.abs().max().item():.2f}")

    # Now X is a realistic hidden state after the warm-up layers.
    # Give each implementation its own copy so neither can disturb the other.
    X_ref = X.clone()
    X_prod = X.clone()
    print(f"\nAfter {warmup_layers} layers: |X| = {X_ref.abs().max().item():.2f}")

    # --- Compare mHC at the next layer ---
    li = compare_layer
    print(f"\n=== Comparing mHC at L{li} ===")

    # Reference mHC (weights already loaded above)
    x_in_ref, ctx_ref = attn_mhcs[li].pre_block(X_ref)
    print(f" Ref x_in: |x| = {x_in_ref.abs().max().item():.4f}")
    print(f" Ref A: {ctx_ref['A'][0].tolist()}")
    print(f" Ref C: {ctx_ref['C'][0].tolist()}")
    print(f" Ref B row_sums: {ctx_ref['B'][0].sum(-1).tolist()}")

    # Production mHC
    from dsv4.layers.mhc import mHCLayer
    prod_mhc = mHCLayer(hidden_dim=H, n_hc=n_hc, device=DEVICE)

    # Load the same checkpoint tensors, split into the pre / post / comb parts
    # that the production layer takes separately.
    fn = ref_all_w[f"model.layers.{li}.attn_hc.fn"].to(DEVICE, torch.float32)
    base = ref_all_w[f"model.layers.{li}.attn_hc.base"].to(DEVICE)
    scale = ref_all_w[f"model.layers.{li}.attn_hc.scale"].to(DEVICE)
    n = n_hc
    prod_mhc.load_weights(
        W_pre=fn[0:n], W_post=fn[n:2*n], W_comb=fn[2*n:],
        S_pre=base[0:n].reshape(1, n), S_post=base[n:2*n].reshape(n, 1),
        S_comb=base[2*n:].reshape(n, n),
        alpha_pre=scale[0].item(), alpha_post=scale[1].item(), alpha_comb=scale[2].item()
    )
    x_in_prod, ctx_prod = prod_mhc.pre_block(X_prod)
    print(f" Prod x_in: |x| = {x_in_prod.abs().max().item():.4f}")
    A_prod = ctx_prod.A_l
    C_prod = ctx_prod.C_l
    B_prod = ctx_prod.B_l
    print(f" Prod A: {A_prod[0].tolist()}")
    print(f" Prod C: {C_prod[0].tolist()}")
    print(f" Prod B row_sums: {B_prod[0].sum(-1).tolist()}")

    # Compare
    cos_xin = _flat_cosine(x_in_ref, x_in_prod)
    cos_A = _flat_cosine(ctx_ref['A'], A_prod)
    cos_C = _flat_cosine(ctx_ref['C'], C_prod)
    cos_B = _flat_cosine(ctx_ref['B'], B_prod)
    print(f"\n cos(x_in): {cos_xin:.6f}")
    print(f" cos(A): {cos_A:.6f}")
    print(f" cos(C): {cos_C:.6f}")
    print(f" cos(B): {cos_B:.6f}")

    print("\nDone.")
|
|
|
|
# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|