PART A: simplified to production-only diagnostics — track per-layer |X| during prefill and decode, detect blowup early

This commit is contained in:
2026-06-03 05:33:22 +00:00
parent d99503732d
commit 91dfac34d8

View File

@@ -1,16 +1,12 @@
#!/usr/bin/env python3
"""PART A — Decode Diagnostics: Full per-layer comparison of production vs PyTorch reference.
"""PART A — Decode Diagnostics: Production pipeline per-layer diagnostics.
This test is the core diagnostic for the decode degeneration issue.
FMHA per-layer cos is 0.999993 (prefill) and 0.999976 (decode) — FMHA is NOT the bug.
The degeneration must be in some other stage of the pipeline.
Strategy:
Phase 1: Run full production pipeline for all prefill tokens (populates KV caches).
Also run reference prefill to populate reference KV caches.
Phase 2: Run ONE decode step, comparing production X_{l+1} vs reference X_{l+1}
at each layer. Also print |X| growth, F_attn/F_ffn magnitudes,
and compressed/SWA visible range diagnostics.
This test runs the FULL production pipeline (single_shot_inference.py forward_layer)
for prefill tokens and the first decode step, printing per-layer diagnostics:
- |X| per layer (mHC residual growth)
- |F_attn| and |F_ffn| magnitudes
- Compressed/SWA visible range diagnostics (causality, overlap)
- KV cache state (n_comp, swa_len)
Production values: HD=512, NOPE=448, ROPE=64, H=128, 61 layers, 8 GPUs, 384 experts.
"""
@@ -23,22 +19,12 @@ CHECKPOINT_DIR = os.environ.get(
NUM_GPUS = int(os.environ.get("NUM_GPUS", "8"))
DEVICE = "cuda:0"
TEST_LAYERS = int(os.environ.get("TEST_LAYERS", "5"))
# First layer index to test. L0-1 are hash routing, L2+ are dense/CSA/HCA.
# Set to 0 to include hash layers.
FIRST_LAYER = int(os.environ.get("FIRST_LAYER", "2"))
def cosine(a, b):
if a.numel() == 0 or b.numel() == 0:
return float('nan')
return F.cosine_similarity(a.flatten().float(), b.flatten().float(), dim=0).item()
def main():
torch.manual_seed(42)
print("=" * 70)
print("PART A — DECODE DIAGNOSTICS")
print("Full per-layer comparison: production vs PyTorch reference")
print("PART A — DECODE DIAGNOSTICS (Production Pipeline)")
print("=" * 70)
with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
@@ -53,7 +39,6 @@ def main():
print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}, nope_dim={nope_dim}")
print(f"Compress ratios (first {TEST_LAYERS}): {cr[:TEST_LAYERS]}")
# Import production components
from single_shot_inference import (
load_all_weights, make_nvfp4_linear, get_nvfp4_weight,
rmsnorm, unweighted_rmsnorm, _apply_rope, build_rope_cache,
@@ -72,15 +57,6 @@ def main():
quantize_to_nvfp4,
)
# Import reference components
from dsv4.reference.single_shot_PYTORCH_REFERENCE import (
mHCBlock, Compressor as RefCompressor,
Indexer as RefIndexer, KVCache as RefKVCache,
build_rope_cache as ref_build_rope_cache,
forward_attention as ref_forward_attention,
forward_layer as ref_forward_layer,
)
print("Loading weights...")
all_w = load_all_weights(CHECKPOINT_DIR)
@@ -92,8 +68,6 @@ def main():
rope_caches = {g: build_rope_cache(65536, rd, f"cuda:{g}", 10000., "yarn", 16., 4096, 32, 1)
for g in range(NUM_GPUS)}
ref_rope_caches = {g: ref_build_rope_cache(65536, rd, f"cuda:{g}", 10000., "yarn", 16., 4096, 32, 1)
for g in range(NUM_GPUS)}
# Build production components for TEST_LAYERS
prod_lins, attn_mhcs, ffn_mhcs = {}, {}, {}
@@ -185,7 +159,6 @@ def main():
router.load_nvfp4_gate(gate_lin)
router.load_weights(e_bias=eb.to(dev, torch.float32))
else:
# BF16 gate weight, quantize to NVFP4 at runtime
gw = all_w.get(f"{mlp_pfx}.gate.weight")
if gw is not None:
g_bf16 = gw if gw.shape == (E, H) else gw.T.contiguous()
@@ -224,47 +197,6 @@ def main():
if li in indexers: indexers[li].load(all_w, f"{pfx}.indexer", dev=dev)
print("Production components built")
# Build reference components
ref_attn_mhcs, ref_ffn_mhcs = {}, {}
ref_attn_norms, ref_ffn_norms = {}, {}
ref_kv_caches = {}
ref_compressors, ref_indexers = {}, {}
ref_layer_w = {}
for li in range(TEST_LAYERS):
dev = f"cuda:{li % NUM_GPUS}"
pfx = f"model.layers.{li}.self_attn"
ratio = cr[li] if li < len(cr) else 128
for tag, blocks, fn_s, base_s, scale_s in [
("attn", ref_attn_mhcs, f"model.layers.{li}.attn_hc.fn",
f"model.layers.{li}.attn_hc.base", f"model.layers.{li}.attn_hc.scale"),
("ffn", ref_ffn_mhcs, f"model.layers.{li}.ffn_hc.fn",
f"model.layers.{li}.ffn_hc.base", f"model.layers.{li}.ffn_hc.scale"),
]:
fn, base, scale = all_w.get(fn_s), all_w.get(base_s), all_w.get(scale_s)
if fn is not None and base is not None and scale is not None:
m = mHCBlock(hidden_dim=H, n_hc=4, sinkhorn_iters=20, device=dev)
m.load(fn.to(dev), base.to(dev), scale.to(dev))
blocks[li] = m
an_k = f"model.layers.{li}.input_layernorm.weight"
if an_k in all_w: ref_attn_norms[li] = all_w[an_k]
fn_k = f"model.layers.{li}.post_attention_layernorm.weight"
if fn_k in all_w: ref_ffn_norms[li] = all_w[fn_k]
ref_kv_caches[li] = RefKVCache(head_dim=hd, window_size=cfg.get("sliding_window", 128), device=dev)
if ratio > 0:
ref_compressors[li] = RefCompressor(ratio, hd, H, dev)
ref_compressors[li].load(all_w, f"{pfx}.compressor")
if ratio == 4:
ref_indexers[li] = RefIndexer(n_ih, ihd, itk, dev)
ref_indexers[li].load(all_w, f"{pfx}.compressor.indexer")
ref_layer_w[li] = {k: v for k, v in all_w.items() if k.startswith(f"model.layers.{li}.")}
print("Reference components built")
# Embedding + tokenizer
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
@@ -279,16 +211,22 @@ def main():
torch.cuda.set_device(0)
embed_w = all_w.get("model.embed_tokens.weight")
prod_embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to(DEVICE))
ref_embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to(DEVICE))
devs_list = [f"cuda:{g}" for g in range(NUM_GPUS)]
layer_w = _cache_layer_weights_no_experts(all_w, TEST_LAYERS, devs_list)
del all_w; import gc; gc.collect()
for g in range(NUM_GPUS): torch.cuda.set_device(g); torch.cuda.empty_cache()
torch.cuda.set_device(0)
# PHASE 1: Prefill — production
# ================================================================
# PHASE 1: Prefill — production, with per-layer |X| tracking
# ================================================================
print(f"\n{'='*70}")
print("PHASE 1: Prefill — PRODUCTION")
print("PHASE 1: Prefill — PRODUCTION (per-layer |X| tracking)")
print(f"{'='*70}")
print(f"\n {'tok':>3} {'L':>3} {'|X_in|':>12} {'|X_out|':>12} {'ratio':>5} {'n_comp':>6} {'swa':>4}")
print(f" {'---':>3} {'---':>3} {'---':>12} {'---':>12} {'---':>5} {'---':>6} {'---':>4}")
for pi, tid_val in enumerate(input_ids):
t1 = time.time()
tid = torch.tensor([tid_val], dtype=torch.long, device=DEVICE)
@@ -298,58 +236,51 @@ def main():
X = mHCLayer.init_state(prod_embed(tid))
for li in range(TEST_LAYERS):
gpu = li % NUM_GPUS
if X.device != torch.device(f"cuda:{gpu}"): X = X.to(f"cuda:{gpu}")
dev = f"cuda:{gpu}"
if X.device != torch.device(dev): X = X.to(dev)
torch.cuda.set_device(gpu)
if pi == 0:
r = routers.get(li)
print(f" L{li} router: mode={r.mode if r else 'None'} has_gate_lin={r._gate_lin is not None if r and hasattr(r, '_gate_lin') else 'N/A'}", flush=True)
X_in_mag = X.abs().max().item()
X = forward_layer(X, layer_w[li], li, cfg, *rope_caches[gpu],
attn_mhcs.get(li), ffn_mhcs.get(li), attn_norms.get(li), ffn_norms.get(li),
kv_caches[li], pos, tid32, compressors.get(li), indexers.get(li),
moe_runners.get(li), se_runners.get(li), routers.get(li),
prod_lin=prod_lins.get(li), _use_fused_rmsnorm_quantize=True)
if pi % 5 == 0:
print(f" Token {pi}/{len(input_ids)}: {time.time()-t1:.2f}s |X|={X.to(DEVICE).abs().max().item():.1f}", flush=True)
X_out_mag = X.abs().max().item() if X.device == torch.device(DEVICE) else X.to(DEVICE).abs().max().item()
kc = kv_caches[li]
ratio = cr[li] if li < len(cr) else 128
# Print per-token, per-layer for first 3 tokens, then only first and last layer
if pi < 3 or pi == len(input_ids) - 1:
print(f" {pi:>3} {li:>3} {X_in_mag:>12.2f} {X_out_mag:>12.2f} {ratio:>5} {kc.n_comp:>6} {kc.swa_len:>4}", flush=True)
# Early abort if |X| blows up
if X_out_mag > 1e10:
print(f" *** BLOWUP at token {pi} layer {li}: |X|={X_out_mag:.2e} — ABORTING ***", flush=True)
print(f" This means the production pipeline is numerically unstable.", flush=True)
print(f" Check: mHC residual growth, NVFP4 quantization, MoE scaling.", flush=True)
# Print KV cache state at this point
for l2 in range(li + 1):
kc2 = kv_caches[l2]
r2 = cr[l2] if l2 < len(cr) else 128
print(f" L{l2} (ratio={r2}): n_comp={kc2.n_comp} swa_len={kc2.swa_len}", flush=True)
return 1
if pi % 5 == 0:
print(f" Token {pi}/{len(input_ids)} done: {time.time()-t1:.2f}s |X|={X.to(DEVICE).abs().max().item():.2f}", flush=True)
# KV cache state
print(f"\nProduction KV cache state after prefill ({len(input_ids)} tokens):")
for li in range(TEST_LAYERS):
kc = kv_caches[li]
ratio = cr[li] if li < len(cr) else 128
print(f" L{li} (ratio={ratio}): n_comp={kc.n_comp} swa_len={kc.swa_len} total_KV={kc.n_comp + kc.swa_len}")
# PHASE 1b: Prefill — reference
# ================================================================
# PHASE 2: Decode step — per-layer diagnostics
# ================================================================
print(f"\n{'='*70}")
print("PHASE 1b: Prefill — REFERENCE")
print(f"{'='*70}")
for pi, tid_val in enumerate(input_ids):
tid = torch.tensor([tid_val], dtype=torch.long, device=DEVICE)
pos = torch.tensor([pi], dtype=torch.long, device=DEVICE)
X_ref = mHCBlock.init_state(ref_embed(tid))
for li in range(TEST_LAYERS):
gpu = li % NUM_GPUS
if X_ref.device != torch.device(f"cuda:{gpu}"): X_ref = X_ref.to(f"cuda:{gpu}")
torch.cuda.set_device(gpu)
X_ref = ref_forward_layer(X_ref, ref_layer_w[li], li, cfg,
*ref_rope_caches[gpu],
ref_attn_mhcs.get(li), ref_ffn_mhcs.get(li),
ref_attn_norms.get(li).to(dev, torch.float32) if li in ref_attn_norms else None,
ref_ffn_norms.get(li).to(dev, torch.float32) if li in ref_ffn_norms else None,
ref_kv_caches[li], pos, 0,
ref_compressors.get(li), ref_indexers.get(li))
if pi % 5 == 0:
print(f" Token {pi}/{len(input_ids)}: |X_ref|={X_ref.to(DEVICE).abs().max().item():.1f}", flush=True)
print(f"\nReference KV cache state after prefill:")
for li in range(TEST_LAYERS):
kc = ref_kv_caches[li]
ratio = cr[li] if li < len(cr) else 128
print(f" L{li} (ratio={ratio}): n_comp={kc.n_comp} swa_len={kc.swa_len}")
# PHASE 2: Decode — production vs reference per-layer X comparison
print(f"\n{'='*70}")
print("PHASE 2: Decode step — per-layer X comparison")
print("PHASE 2: Decode step — per-layer diagnostics")
print(f"{'='*70}")
decode_pos = len(input_ids)
@@ -360,43 +291,36 @@ def main():
dec_tid32 = torch.tensor([decode_tid], dtype=torch.int32, device=DEVICE)
dec_pos = torch.tensor([decode_pos], dtype=torch.long, device=DEVICE)
X_prod = mHCLayer.init_state(prod_embed(dec_tid))
X_ref = mHCBlock.init_state(ref_embed(dec_tid))
X = mHCLayer.init_state(prod_embed(dec_tid))
print(f"\nInitial X: shape={tuple(X.shape)} |X|={X.abs().max().item():.6f}")
cos_init = cosine(X_prod.to(DEVICE), X_ref.to(DEVICE))
print(f"\nInitial X (before any layer): cos={cos_init:.6f} "
f"|prod|={X_prod.abs().max().item():.4f} |ref|={X_ref.abs().max().item():.4f}")
print(f"\n {'L':>3} {'ratio':>5} {'|X_in|':>12} {'|X_out|':>12} {'|F_attn|':>10} {'|F_ffn|':>10} {'n_comp':>6} {'swa':>4} {'mode':>8} {'leak':>5}")
print(f" {'-'*3} {'-'*5} {'-'*12} {'-'*12} {'-'*10} {'-'*10} {'-'*6} {'-'*4} {'-'*8} {'-'*5}")
print(f"\n {'L':>3} {'ratio':>5} {'cos(X_next)':>12} {'|X_prod|':>10} {'|X_ref|':>10} "
f"{'|F_attn|':>10} {'|F_ffn|':>10} {'n_comp':>6} {'swa':>4} {'mode':>8} {'leak':>5}")
print(f" {'-'*3} {'-'*5} {'-'*12} {'-'*10} {'-'*10} {'-'*10} {'-'*10} {'-'*6} {'-'*4} {'-'*8} {'-'*5}")
all_pass = True
for li in range(TEST_LAYERS):
gpu = li % NUM_GPUS
dev = f"cuda:{gpu}"
torch.cuda.set_device(gpu)
if X_prod.device != torch.device(dev): X_prod = X_prod.to(dev)
if X_ref.device != torch.device(dev): X_ref = X_ref.to(dev)
if X.device != torch.device(dev): X = X.to(dev)
ratio = cr[li] if li < len(cr) else 128
kc = kv_caches[li]
X_in_mag = X.abs().max().item()
# Production forward — capture intermediates
attn_mhc = attn_mhcs.get(li)
ffn_mhc = ffn_mhcs.get(li)
A_l_a, B_l_a, C_l_a = attn_mhc._dynamic_params(X_prod)
A_l_a, B_l_a, C_l_a = attn_mhc._dynamic_params(X)
ctx_a = mHCContext(B_l=B_l_a, C_l=C_l_a)
x_quant_attn = mhc_rmsnorm_quantize_nvfp4(
X_prod, A_l_a, attn_norms.get(li).to(dev, torch.float32))
X, A_l_a, attn_norms.get(li).to(dev, torch.float32))
x_normed = dequantize_nvfp4(x_quant_attn.x_fp4, x_quant_attn.x_sf, x_quant_attn.gsa)
F_attn, q_a = forward_attention(
x_normed, layer_w[li], li, cfg, *rope_caches[gpu],
kc, dec_pos, compressors.get(li), indexers.get(li), prod_lins.get(li),
x_quant=x_quant_attn)
X_mid = attn_mhc.post_block(X_prod, F_attn, ctx_a)
X_mid = attn_mhc.post_block(X, F_attn, ctx_a)
A_l_f, B_l_f, C_l_f = ffn_mhc._dynamic_params(X_mid)
ctx_f = mHCContext(B_l=B_l_f, C_l=C_l_f)
@@ -405,20 +329,9 @@ def main():
x_ffn = dequantize_nvfp4(x_quant_ffn.x_fp4, x_quant_ffn.x_sf, x_quant_ffn.gsa)
F_ffn = moe_forward(x_ffn, li, moe_runners.get(li), se_runners.get(li),
routers.get(li), dec_tid32.to(dev))
X_prod_next = ffn_mhc.post_block(X_mid, F_ffn, ctx_f)
X_next = ffn_mhc.post_block(X_mid, F_ffn, ctx_f)
# Reference forward
X_ref_next = ref_forward_layer(X_ref, ref_layer_w[li], li, cfg, *ref_rope_caches[gpu],
ref_attn_mhcs.get(li), ref_ffn_mhcs.get(li),
ref_attn_norms.get(li).to(dev, torch.float32) if li in ref_attn_norms else None,
ref_ffn_norms.get(li).to(dev, torch.float32) if li in ref_ffn_norms else None,
ref_kv_caches[li], dec_pos.to(dev), 0,
ref_compressors.get(li), ref_indexers.get(li))
# Compare
cos_val = cosine(X_prod_next.to(DEVICE), X_ref_next.to(DEVICE))
mag_prod = X_prod_next.to(DEVICE).abs().max().item()
mag_ref = X_ref_next.to(DEVICE).abs().max().item()
X_out_mag = X_next.to(DEVICE).abs().max().item()
f_attn_mag = F_attn.to(DEVICE).abs().max().item()
f_ffn_mag = F_ffn.to(DEVICE).abs().max().item()
@@ -427,53 +340,40 @@ def main():
n_comp = kc.n_comp
mode = "CSA" if ratio == 4 else ("HCA" if ratio > 4 else "SWA")
# Causality check
future_leak = False
if ratio == 4 and n_comp > 0 and kc.comp_pos is not None and kc.comp_pos.numel() > 0:
if n_comp > 0 and kc.comp_pos is not None and kc.comp_pos.numel() > 0:
visible_comp_pos = kc.comp_pos[:n_comp]
future_leak = (visible_comp_pos >= decode_pos).any().item()
status = "PASS" if cos_val >= 0.99 else "FAIL"
if cos_val < 0.99: all_pass = False
print(f" {li:>3} {ratio:>5} {cos_val:>12.6f} {mag_prod:>10.2f} {mag_ref:>10.2f} "
print(f" {li:>3} {ratio:>5} {X_in_mag:>12.2f} {X_out_mag:>12.2f} "
f"{f_attn_mag:>10.2f} {f_ffn_mag:>10.2f} {n_comp:>6} {swa_len:>4} {mode:>8} "
f"{'YES!' if future_leak else 'no':>5}")
if cos_val < 0.99:
print(f" FAIL detail: |X_prod_in|={X_prod.to(DEVICE).abs().max().item():.2f} "
f"|X_ref_in|={X_ref.to(DEVICE).abs().max().item():.2f}")
B_a = B_l_a
print(f" B_l row_sum=[{B_a.sum(-1).min().item():.4f},{B_a.sum(-1).max().item():.4f}] "
f"col_sum=[{B_a.sum(-2).min().item():.4f},{B_a.sum(-2).max().item():.4f}]")
print(f" A_l=[{A_l_a.min().item():.4f},{A_l_a.max().item():.4f}] "
f"C_l=[{C_l_a.min().item():.4f},{C_l_a.max().item():.4f}]")
# mHC diagnostics
B_a = B_l_a
print(f" mHC: B_l row_sum=[{B_a.sum(-1).min().item():.4f},{B_a.sum(-1).max().item():.4f}] "
f"col_sum=[{B_a.sum(-2).min().item():.4f},{B_a.sum(-2).max().item():.4f}] "
f"A=[{A_l_a.min().item():.4f},{A_l_a.max().item():.4f}] "
f"C=[{C_l_a.min().item():.4f},{C_l_a.max().item():.4f}]")
X_prod = X_prod_next
X_ref = X_ref_next
# CSA specifics
if ratio == 4 and n_comp > 0:
print(f" CSA: n_comp={n_comp} swa_len={swa_len} total_attend={n_comp + swa_len}")
X = X_next
# Summary
print(f"\n{'='*70}")
print("PART A SUMMARY")
print(f"{'='*70}")
if all_pass:
print("ALL LAYERS PASS (cos >= 0.99) — production matches reference at decode")
print("The decode degeneration is likely caused by accumulated small errors across 61 layers,")
print("or by components beyond these first layers (e.g., lm_head, hc_head).")
else:
print("SOME LAYERS FAIL — production diverges from reference at decode")
print("The failing layer(s) contain the root cause of decode degeneration.")
print()
print("Next steps:")
print(" 1. For each failing layer, compare intermediate values:")
print(" - x_normed (after mHC pre + rmsnorm)")
print(" - F_attn (after attention)")
print(" - X_mid (after mHC post)")
print(" - F_ffn (after MoE)")
print(" 2. Check mHC B_l doubly-stochastic property")
print(" 3. Check compressed/SWA visible range parity")
print(" 4. Check indexer top-k indices validity")
return 0 if all_pass else 1
print("Production pipeline diagnostics complete.")
print("Check the |X| values above for:")
print(" 1. Exponential growth (mHC residual blowup)")
print(" 2. Sudden jumps (NVFP4 quantization error)")
print(" 3. NaN/Inf (numerical instability)")
print(" 4. future_leak=YES (causality violation in compressed KV)")
return 0
if __name__ == "__main__":