Files
nvfp4-megamoe-kernel/tests/unit/test_mhc_comparison.py
biondizzle 8de47e26ce Cleanup Step 1: Move root-level files to proper directories
- Move test_*.py → tests/integration/
- Move probe_*.py, dump_*.py → helpers/
- Move PERFORMANCE_AUDIT.md → docs/
- Move single_shot_PYTORCH_REFERENCE.py → dsv4/reference/
- Fix 3 import references in test_layer_comparison, test_mhc_comparison, test_compressor_position_bias
- Add helpers/import_closure.py (dead-code detection tool)
2026-06-02 19:24:39 +00:00

170 lines
7.0 KiB
Python

#!/usr/bin/env python3
"""Focused comparison: production mHC vs PyTorch reference mHC at one layer.

This script:
1. Loads the model config and checkpoint weights from CHECKPOINT_DIR.
2. Runs one token ("The", id 455) through the first 3 decoder layers of the
   reference pipeline to obtain a realistic mHC state X.
3. Feeds identical copies of X to the reference ``mHCBlock`` and the
   production ``mHCLayer``, both loaded with layer 3's ``attn_hc`` weights,
   and calls ``pre_block`` on each.
4. Prints x_in, A, C and the row sums of B from both, plus the cosine
   similarity of each reference/production pair, to show where they diverge.

Needs a CUDA device (DEVICE) and the checkpoint on disk.
"""
import os, sys, json, time, math, torch, torch.nn.functional as F
from pathlib import Path
# Default checkpoint root; the CHECKPOINT_DIR environment variable overrides it.
_DEFAULT_CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
CHECKPOINT_DIR = os.getenv("CHECKPOINT_DIR", _DEFAULT_CHECKPOINT_DIR)
# Device every tensor in the comparison is placed on.
DEVICE = "cuda:0"
# Additive epsilon used by the Sinkhorn normalisation (softmax offset and
# denominator guard).
HC_EPS = 1e-6


def sinkhorn_knopp(logits, t_max=20, eps=HC_EPS):
    """Push ``logits`` toward a doubly-stochastic matrix (Sinkhorn-Knopp).

    A softmax over the last dim (plus ``eps``) is column-normalised once, then
    ``t_max - 1`` rounds of row-normalise + column-normalise follow. Both
    normalisations act on the last two dims, and every denominator carries
    ``+ eps``. Because the final step normalises columns, column sums end
    closest to 1.

    Args:
        logits: tensor whose last two dims form the square mixing matrix.
        t_max: total number of column normalisations performed.
        eps: offset added after the softmax and to each denominator.
    """
    def col_normalize(m):
        return m / (m.sum(-2, keepdim=True) + eps)

    def row_normalize(m):
        return m / (m.sum(-1, keepdim=True) + eps)

    mat = col_normalize(torch.softmax(logits, dim=-1) + eps)
    for _ in range(t_max - 1):
        mat = col_normalize(row_normalize(mat))
    return mat
def unweighted_rmsnorm(x, eps=1e-6):
    """RMS-normalise ``x`` over its last dimension without a learned weight.

    The arithmetic is done in fp32: x / sqrt(mean(x**2) + eps). The result is
    cast back to ``x``'s original dtype.
    """
    xf = x.float()
    inv_rms = torch.rsqrt(xf.square().mean(dim=-1, keepdim=True) + eps)
    return (xf * inv_rms).to(x.dtype)
def rmsnorm(x, w, eps=1e-6):
    """Weighted RMSNorm over the last dimension, computed in fp32.

    Returns x / sqrt(mean(x**2) + eps) * w, cast back to ``x``'s dtype.
    ``w`` is upcast to fp32 before the multiply.
    """
    xf = x.float()
    inv_rms = torch.rsqrt(xf.square().mean(dim=-1, keepdim=True) + eps)
    normed = xf * inv_rms
    return (normed * w.float()).to(x.dtype)
# FP4 (E2M1) magnitudes, indexed by the low 3 bits of a nibble; bit 3 is the sign.
FP4_LUT = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def dequant_nvfp4(weight, weight_scale, weight_scale_2=None, input_scale=None):
    """Dequantise a packed NVFP4 weight matrix to bfloat16.

    Args:
        weight: 2-D tensor of packed bytes, shape (rows, cols // 2). The low
            nibble of each byte is the earlier column and the high nibble the
            later one.
        weight_scale: per-block scales, shape (rows, cols // 16). Each scale is
            repeated over 16 consecutive unpacked columns.
        weight_scale_2: optional extra scale multiplied into every block scale
            (broadcast).
        input_scale: accepted for signature compatibility; not used here.

    Returns:
        bfloat16 tensor of shape (rows, cols).
    """
    rows, packed_cols = weight.shape
    magnitudes = FP4_LUT.to(device=weight.device, dtype=torch.float32)

    def decode(nibble):
        # Bit 3 selects the sign, bits 0-2 index the magnitude table.
        sign = torch.where((nibble >> 3).bool(), -1., 1.)
        return magnitudes[(nibble & 0x07).long()] * sign

    low = decode((weight & 0x0F).to(torch.int8))
    high = decode((weight >> 4).to(torch.int8))
    values = torch.stack((low, high), dim=-1).reshape(rows, packed_cols * 2)

    scales = weight_scale.float().repeat_interleave(16, dim=1)
    if weight_scale_2 is not None:
        scales = scales * weight_scale_2.float()
    return (values * scales).bfloat16()
def _cosine(a, b):
    """Return the cosine similarity of ``a`` and ``b`` as a Python float.

    Both tensors are flattened and upcast to fp32 first. They therefore only
    need the same number of elements, not the same shape or dtype.
    """
    return F.cosine_similarity(a.flatten().float(), b.flatten().float(), dim=0).item()


@torch.no_grad()
def main():
    """Compare reference vs production mHC ``pre_block`` at one decoder layer.

    Steps:
    - Read ``config.json`` and the reference weights from ``CHECKPOINT_DIR``.
    - Run one token through reference layers ``[0, target)``.
    - Feed the resulting state into both the reference ``mHCBlock`` and the
      production ``mHCLayer``, each loaded with layer ``target``'s ``attn_hc``
      weights.
    - Print their outputs and the cosine similarity of each pair.

    Runs under ``torch.no_grad()``: nothing here needs gradients, and autograd
    graph recording over several layers only wastes memory.

    Raises:
        ValueError: if the config has too few layers to reach ``target``.
    """
    torch.manual_seed(42)
    with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
        cfg = json.load(f)
    H = cfg["hidden_size"]
    n_hc = cfg.get("n_hc", 4)
    n_layers = cfg["num_hidden_layers"]
    n_experts = cfg["n_routed_experts"]
    top_k = cfg.get("num_experts_per_tok", 6)
    print(f"Model: {n_layers} layers, {H} hidden, {n_experts} experts, top-{top_k}")

    # Layer whose attention-side mHC is compared. Layers [0, target) are run
    # through the reference first, so the compared input is a realistic state
    # rather than a raw embedding.
    target = 3
    if n_layers <= target:
        raise ValueError(
            f"config has {n_layers} layers; need more than {target} to compare L{target}"
        )

    # The reference loader is the only weight source. The production layer
    # below is loaded from the same tensors, so loading the shards a second
    # time would just double host memory.
    print("Loading weights...")
    from dsv4.reference.single_shot_PYTORCH_REFERENCE import (
        forward_layer,
        load_weights as ref_load_weights,
        mHCBlock,
    )
    ref_all_w = ref_load_weights(CHECKPOINT_DIR)
    print(f"Loaded reference weights from {CHECKPOINT_DIR}")

    def ref_mhc(li, kind):
        """Build a reference mHCBlock from model.layers.{li}.{kind}.{fn,base,scale}."""
        block = mHCBlock(H, n_hc, device=DEVICE)
        prefix = f"model.layers.{li}.{kind}"
        block.load(ref_all_w[f"{prefix}.fn"],
                   ref_all_w[f"{prefix}.base"],
                   ref_all_w[f"{prefix}.scale"])
        return block

    def ref_norm(li, name):
        """Return model.layers.{li}.{name}.weight as bf16 on DEVICE."""
        return ref_all_w[f"model.layers.{li}.{name}.weight"].bfloat16().to(DEVICE)

    # Reference blocks for the warm-up layers plus the target layer.
    layers = range(target + 1)
    attn_mhcs = [ref_mhc(li, "attn_hc") for li in layers]
    ffn_mhcs = [ref_mhc(li, "ffn_hc") for li in layers]
    attn_norms = [ref_norm(li, "input_layernorm") for li in layers]
    ffn_norms = [ref_norm(li, "post_attention_layernorm") for li in layers]

    # Index the embedding table directly. Building an nn.Embedding would first
    # allocate and randomly initialise a full fp32 vocab x hidden table, only
    # to overwrite it.
    emb_w = ref_all_w["model.embed_tokens.weight"].bfloat16().to(DEVICE)
    tid = 455  # "The" token
    X = mHCBlock.init_state(F.embedding(torch.tensor([tid], device=DEVICE), emb_w), n_hc=n_hc)
    print(f"\nInitial |X| = {X.abs().max().item():.2f}")

    # Warm up: run the token through the first `target` reference layers.
    kv_cache = {}
    for li in range(target):
        # NOTE(review): position [3] is passed at every layer for a single
        # token — confirm this matches what forward_layer expects.
        X = forward_layer(X, ref_all_w, li, cfg, None, None,
                          attn_mhcs[li], ffn_mhcs[li],
                          attn_norms[li], ffn_norms[li],
                          kv_cache, torch.tensor([3], device=DEVICE),
                          tid)
        print(f" Ref L{li}: |X| = {X.abs().max().item():.2f}")

    # Give each side its own copy so neither pre_block can affect the other.
    X_ref = X.clone()
    X_prod = X.clone()
    print(f"\nAfter {target} layers: |X| = {X_ref.abs().max().item():.2f}")

    li = target
    print(f"\n=== Comparing mHC at L{li} ===")

    # Reference mHC.
    x_in_ref, ctx_ref = attn_mhcs[li].pre_block(X_ref)
    print(f" Ref x_in: |x| = {x_in_ref.abs().max().item():.4f}")
    print(f" Ref A: {ctx_ref['A'][0].tolist()}")
    print(f" Ref C: {ctx_ref['C'][0].tolist()}")
    print(f" Ref B row_sums: {ctx_ref['B'][0].sum(-1).tolist()}")

    # Production mHC, loaded from the same layer's attn_hc tensors.
    from dsv4.layers.mhc import mHCLayer
    prod_mhc = mHCLayer(hidden_dim=H, n_hc=n_hc, device=DEVICE)
    fn = ref_all_w[f"model.layers.{li}.attn_hc.fn"].to(DEVICE, torch.float32)
    base = ref_all_w[f"model.layers.{li}.attn_hc.base"].to(DEVICE)
    scale = ref_all_w[f"model.layers.{li}.attn_hc.scale"].to(DEVICE)
    n = n_hc
    # fn/base are split into pre [0, n), post [n, 2n) and comb [2n, ...) parts;
    # scale holds the three scalars in pre/post/comb order.
    prod_mhc.load_weights(
        W_pre=fn[0:n], W_post=fn[n:2 * n], W_comb=fn[2 * n:],
        S_pre=base[0:n].reshape(1, n), S_post=base[n:2 * n].reshape(n, 1),
        S_comb=base[2 * n:].reshape(n, n),
        alpha_pre=scale[0].item(), alpha_post=scale[1].item(), alpha_comb=scale[2].item(),
    )
    x_in_prod, ctx_prod = prod_mhc.pre_block(X_prod)
    print(f" Prod x_in: |x| = {x_in_prod.abs().max().item():.4f}")
    A_prod, B_prod, C_prod = ctx_prod.A_l, ctx_prod.B_l, ctx_prod.C_l
    print(f" Prod A: {A_prod[0].tolist()}")
    print(f" Prod C: {C_prod[0].tolist()}")
    print(f" Prod B row_sums: {B_prod[0].sum(-1).tolist()}")

    # Compare.
    cos_xin = _cosine(x_in_ref, x_in_prod)
    cos_A = _cosine(ctx_ref['A'], A_prod)
    cos_C = _cosine(ctx_ref['C'], C_prod)
    cos_B = _cosine(ctx_ref['B'], B_prod)
    print(f"\n cos(x_in): {cos_xin:.6f}")
    print(f" cos(A): {cos_A:.6f}")
    print(f" cos(C): {cos_C:.6f}")
    print(f" cos(B): {cos_B:.6f}")
    print("\nDone.")


if __name__ == "__main__":
    main()