From 91dfac34d81f532ebf196181ff2cef7afdc261cd Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 3 Jun 2026 05:33:22 +0000 Subject: [PATCH] =?UTF-8?q?PART=20A:=20simplified=20to=20production-only?= =?UTF-8?q?=20diagnostics=20=E2=80=94=20track=20per-layer=20|X|=20during?= =?UTF-8?q?=20prefill=20and=20decode,=20detect=20blowup=20early?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/unit/test_part_a_decode_diagnostics.py | 260 ++++++------------- 1 file changed, 80 insertions(+), 180 deletions(-) diff --git a/tests/unit/test_part_a_decode_diagnostics.py b/tests/unit/test_part_a_decode_diagnostics.py index c0ae5c49..67769d8a 100644 --- a/tests/unit/test_part_a_decode_diagnostics.py +++ b/tests/unit/test_part_a_decode_diagnostics.py @@ -1,16 +1,12 @@ #!/usr/bin/env python3 -"""PART A — Decode Diagnostics: Full per-layer comparison of production vs PyTorch reference. +"""PART A — Decode Diagnostics: Production pipeline per-layer diagnostics. -This test is the core diagnostic for the decode degeneration issue. -FMHA per-layer cos is 0.999993 (prefill) and 0.999976 (decode) — FMHA is NOT the bug. -The degeneration must be in some other stage of the pipeline. - -Strategy: - Phase 1: Run full production pipeline for all prefill tokens (populates KV caches). - Also run reference prefill to populate reference KV caches. - Phase 2: Run ONE decode step, comparing production X_{l+1} vs reference X_{l+1} - at each layer. Also print |X| growth, F_attn/F_ffn magnitudes, - and compressed/SWA visible range diagnostics. 
+This test runs the FULL production pipeline (single_shot_inference.py forward_layer) +for prefill tokens and the first decode step, printing per-layer diagnostics: + - |X| per layer (mHC residual growth) + - |F_attn| and |F_ffn| magnitudes + - Compressed/SWA visible range diagnostics (causality, overlap) + - KV cache state (n_comp, swa_len) Production values: HD=512, NOPE=448, ROPE=64, H=128, 61 layers, 8 GPUs, 384 experts. """ @@ -23,22 +19,12 @@ CHECKPOINT_DIR = os.environ.get( NUM_GPUS = int(os.environ.get("NUM_GPUS", "8")) DEVICE = "cuda:0" TEST_LAYERS = int(os.environ.get("TEST_LAYERS", "5")) -# First layer index to test. L0-1 are hash routing, L2+ are dense/CSA/HCA. -# Set to 0 to include hash layers. -FIRST_LAYER = int(os.environ.get("FIRST_LAYER", "2")) - - -def cosine(a, b): - if a.numel() == 0 or b.numel() == 0: - return float('nan') - return F.cosine_similarity(a.flatten().float(), b.flatten().float(), dim=0).item() def main(): torch.manual_seed(42) print("=" * 70) - print("PART A — DECODE DIAGNOSTICS") - print("Full per-layer comparison: production vs PyTorch reference") + print("PART A — DECODE DIAGNOSTICS (Production Pipeline)") print("=" * 70) with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f: @@ -53,7 +39,6 @@ def main(): print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}, nope_dim={nope_dim}") print(f"Compress ratios (first {TEST_LAYERS}): {cr[:TEST_LAYERS]}") - # Import production components from single_shot_inference import ( load_all_weights, make_nvfp4_linear, get_nvfp4_weight, rmsnorm, unweighted_rmsnorm, _apply_rope, build_rope_cache, @@ -72,15 +57,6 @@ def main(): quantize_to_nvfp4, ) - # Import reference components - from dsv4.reference.single_shot_PYTORCH_REFERENCE import ( - mHCBlock, Compressor as RefCompressor, - Indexer as RefIndexer, KVCache as RefKVCache, - build_rope_cache as ref_build_rope_cache, - forward_attention as ref_forward_attention, - forward_layer as ref_forward_layer, - ) - print("Loading 
weights...") all_w = load_all_weights(CHECKPOINT_DIR) @@ -92,8 +68,6 @@ def main(): rope_caches = {g: build_rope_cache(65536, rd, f"cuda:{g}", 10000., "yarn", 16., 4096, 32, 1) for g in range(NUM_GPUS)} - ref_rope_caches = {g: ref_build_rope_cache(65536, rd, f"cuda:{g}", 10000., "yarn", 16., 4096, 32, 1) - for g in range(NUM_GPUS)} # Build production components for TEST_LAYERS prod_lins, attn_mhcs, ffn_mhcs = {}, {}, {} @@ -185,7 +159,6 @@ def main(): router.load_nvfp4_gate(gate_lin) router.load_weights(e_bias=eb.to(dev, torch.float32)) else: - # BF16 gate weight, quantize to NVFP4 at runtime gw = all_w.get(f"{mlp_pfx}.gate.weight") if gw is not None: g_bf16 = gw if gw.shape == (E, H) else gw.T.contiguous() @@ -224,47 +197,6 @@ def main(): if li in indexers: indexers[li].load(all_w, f"{pfx}.indexer", dev=dev) print("Production components built") - # Build reference components - ref_attn_mhcs, ref_ffn_mhcs = {}, {} - ref_attn_norms, ref_ffn_norms = {}, {} - ref_kv_caches = {} - ref_compressors, ref_indexers = {}, {} - ref_layer_w = {} - - for li in range(TEST_LAYERS): - dev = f"cuda:{li % NUM_GPUS}" - pfx = f"model.layers.{li}.self_attn" - ratio = cr[li] if li < len(cr) else 128 - - for tag, blocks, fn_s, base_s, scale_s in [ - ("attn", ref_attn_mhcs, f"model.layers.{li}.attn_hc.fn", - f"model.layers.{li}.attn_hc.base", f"model.layers.{li}.attn_hc.scale"), - ("ffn", ref_ffn_mhcs, f"model.layers.{li}.ffn_hc.fn", - f"model.layers.{li}.ffn_hc.base", f"model.layers.{li}.ffn_hc.scale"), - ]: - fn, base, scale = all_w.get(fn_s), all_w.get(base_s), all_w.get(scale_s) - if fn is not None and base is not None and scale is not None: - m = mHCBlock(hidden_dim=H, n_hc=4, sinkhorn_iters=20, device=dev) - m.load(fn.to(dev), base.to(dev), scale.to(dev)) - blocks[li] = m - - an_k = f"model.layers.{li}.input_layernorm.weight" - if an_k in all_w: ref_attn_norms[li] = all_w[an_k] - fn_k = f"model.layers.{li}.post_attention_layernorm.weight" - if fn_k in all_w: ref_ffn_norms[li] = 
all_w[fn_k] - - ref_kv_caches[li] = RefKVCache(head_dim=hd, window_size=cfg.get("sliding_window", 128), device=dev) - if ratio > 0: - ref_compressors[li] = RefCompressor(ratio, hd, H, dev) - ref_compressors[li].load(all_w, f"{pfx}.compressor") - if ratio == 4: - ref_indexers[li] = RefIndexer(n_ih, ihd, itk, dev) - ref_indexers[li].load(all_w, f"{pfx}.compressor.indexer") - - ref_layer_w[li] = {k: v for k, v in all_w.items() if k.startswith(f"model.layers.{li}.")} - - print("Reference components built") - # Embedding + tokenizer from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR) @@ -279,16 +211,22 @@ def main(): torch.cuda.set_device(0) embed_w = all_w.get("model.embed_tokens.weight") prod_embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to(DEVICE)) - ref_embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to(DEVICE)) devs_list = [f"cuda:{g}" for g in range(NUM_GPUS)] layer_w = _cache_layer_weights_no_experts(all_w, TEST_LAYERS, devs_list) + del all_w; import gc; gc.collect() + for g in range(NUM_GPUS): torch.cuda.set_device(g); torch.cuda.empty_cache() torch.cuda.set_device(0) - # PHASE 1: Prefill — production + # ================================================================ + # PHASE 1: Prefill — production, with per-layer |X| tracking + # ================================================================ print(f"\n{'='*70}") - print("PHASE 1: Prefill — PRODUCTION") + print("PHASE 1: Prefill — PRODUCTION (per-layer |X| tracking)") print(f"{'='*70}") + print(f"\n {'tok':>3} {'L':>3} {'|X_in|':>12} {'|X_out|':>12} {'ratio':>5} {'n_comp':>6} {'swa':>4}") + print(f" {'---':>3} {'---':>3} {'---':>12} {'---':>12} {'---':>5} {'---':>6} {'---':>4}") + for pi, tid_val in enumerate(input_ids): t1 = time.time() tid = torch.tensor([tid_val], dtype=torch.long, device=DEVICE) @@ -298,58 +236,51 @@ def main(): X = mHCLayer.init_state(prod_embed(tid)) for li in range(TEST_LAYERS): gpu = li % NUM_GPUS - if 
X.device != torch.device(f"cuda:{gpu}"): X = X.to(f"cuda:{gpu}") + dev = f"cuda:{gpu}" + if X.device != torch.device(dev): X = X.to(dev) torch.cuda.set_device(gpu) - if pi == 0: - r = routers.get(li) - print(f" L{li} router: mode={r.mode if r else 'None'} has_gate_lin={r._gate_lin is not None if r and hasattr(r, '_gate_lin') else 'N/A'}", flush=True) + + X_in_mag = X.abs().max().item() X = forward_layer(X, layer_w[li], li, cfg, *rope_caches[gpu], attn_mhcs.get(li), ffn_mhcs.get(li), attn_norms.get(li), ffn_norms.get(li), kv_caches[li], pos, tid32, compressors.get(li), indexers.get(li), moe_runners.get(li), se_runners.get(li), routers.get(li), prod_lin=prod_lins.get(li), _use_fused_rmsnorm_quantize=True) - if pi % 5 == 0: - print(f" Token {pi}/{len(input_ids)}: {time.time()-t1:.2f}s |X|={X.to(DEVICE).abs().max().item():.1f}", flush=True) + X_out_mag = X.abs().max().item() if X.device == torch.device(DEVICE) else X.to(DEVICE).abs().max().item() + kc = kv_caches[li] + ratio = cr[li] if li < len(cr) else 128 + # Print every layer for the first 3 prefill tokens and for the last prefill token + if pi < 3 or pi == len(input_ids) - 1: + print(f" {pi:>3} {li:>3} {X_in_mag:>12.2f} {X_out_mag:>12.2f} {ratio:>5} {kc.n_comp:>6} {kc.swa_len:>4}", flush=True) + + # Early abort if |X| blows up + if X_out_mag > 1e10: + print(f" *** BLOWUP at token {pi} layer {li}: |X|={X_out_mag:.2e} — ABORTING ***", flush=True) + print(f" This means the production pipeline is numerically unstable.", flush=True) + print(f" Check: mHC residual growth, NVFP4 quantization, MoE scaling.", flush=True) + # Print KV cache state at this point + for l2 in range(li + 1): + kc2 = kv_caches[l2] + r2 = cr[l2] if l2 < len(cr) else 128 + print(f" L{l2} (ratio={r2}): n_comp={kc2.n_comp} swa_len={kc2.swa_len}", flush=True) + return 1 + + if pi % 5 == 0: + print(f" Token {pi}/{len(input_ids)} done: {time.time()-t1:.2f}s |X|={X.to(DEVICE).abs().max().item():.2f}", flush=True) + + # KV cache state print(f"\nProduction
KV cache state after prefill ({len(input_ids)} tokens):") for li in range(TEST_LAYERS): kc = kv_caches[li] ratio = cr[li] if li < len(cr) else 128 print(f" L{li} (ratio={ratio}): n_comp={kc.n_comp} swa_len={kc.swa_len} total_KV={kc.n_comp + kc.swa_len}") - # PHASE 1b: Prefill — reference + # ================================================================ + # PHASE 2: Decode step — per-layer diagnostics + # ================================================================ print(f"\n{'='*70}") - print("PHASE 1b: Prefill — REFERENCE") - print(f"{'='*70}") - - for pi, tid_val in enumerate(input_ids): - tid = torch.tensor([tid_val], dtype=torch.long, device=DEVICE) - pos = torch.tensor([pi], dtype=torch.long, device=DEVICE) - - X_ref = mHCBlock.init_state(ref_embed(tid)) - for li in range(TEST_LAYERS): - gpu = li % NUM_GPUS - if X_ref.device != torch.device(f"cuda:{gpu}"): X_ref = X_ref.to(f"cuda:{gpu}") - torch.cuda.set_device(gpu) - X_ref = ref_forward_layer(X_ref, ref_layer_w[li], li, cfg, - *ref_rope_caches[gpu], - ref_attn_mhcs.get(li), ref_ffn_mhcs.get(li), - ref_attn_norms.get(li).to(dev, torch.float32) if li in ref_attn_norms else None, - ref_ffn_norms.get(li).to(dev, torch.float32) if li in ref_ffn_norms else None, - ref_kv_caches[li], pos, 0, - ref_compressors.get(li), ref_indexers.get(li)) - if pi % 5 == 0: - print(f" Token {pi}/{len(input_ids)}: |X_ref|={X_ref.to(DEVICE).abs().max().item():.1f}", flush=True) - - print(f"\nReference KV cache state after prefill:") - for li in range(TEST_LAYERS): - kc = ref_kv_caches[li] - ratio = cr[li] if li < len(cr) else 128 - print(f" L{li} (ratio={ratio}): n_comp={kc.n_comp} swa_len={kc.swa_len}") - - # PHASE 2: Decode — production vs reference per-layer X comparison - print(f"\n{'='*70}") - print("PHASE 2: Decode step — per-layer X comparison") + print("PHASE 2: Decode step — per-layer diagnostics") print(f"{'='*70}") decode_pos = len(input_ids) @@ -360,43 +291,36 @@ def main(): dec_tid32 = torch.tensor([decode_tid], 
dtype=torch.int32, device=DEVICE) dec_pos = torch.tensor([decode_pos], dtype=torch.long, device=DEVICE) - X_prod = mHCLayer.init_state(prod_embed(dec_tid)) - X_ref = mHCBlock.init_state(ref_embed(dec_tid)) + X = mHCLayer.init_state(prod_embed(dec_tid)) + print(f"\nInitial X: shape={tuple(X.shape)} |X|={X.abs().max().item():.6f}") - cos_init = cosine(X_prod.to(DEVICE), X_ref.to(DEVICE)) - print(f"\nInitial X (before any layer): cos={cos_init:.6f} " - f"|prod|={X_prod.abs().max().item():.4f} |ref|={X_ref.abs().max().item():.4f}") + print(f"\n {'L':>3} {'ratio':>5} {'|X_in|':>12} {'|X_out|':>12} {'|F_attn|':>10} {'|F_ffn|':>10} {'n_comp':>6} {'swa':>4} {'mode':>8} {'leak':>5}") + print(f" {'-'*3} {'-'*5} {'-'*12} {'-'*12} {'-'*10} {'-'*10} {'-'*6} {'-'*4} {'-'*8} {'-'*5}") - print(f"\n {'L':>3} {'ratio':>5} {'cos(X_next)':>12} {'|X_prod|':>10} {'|X_ref|':>10} " - f"{'|F_attn|':>10} {'|F_ffn|':>10} {'n_comp':>6} {'swa':>4} {'mode':>8} {'leak':>5}") - print(f" {'-'*3} {'-'*5} {'-'*12} {'-'*10} {'-'*10} {'-'*10} {'-'*10} {'-'*6} {'-'*4} {'-'*8} {'-'*5}") - - all_pass = True for li in range(TEST_LAYERS): gpu = li % NUM_GPUS dev = f"cuda:{gpu}" torch.cuda.set_device(gpu) - - if X_prod.device != torch.device(dev): X_prod = X_prod.to(dev) - if X_ref.device != torch.device(dev): X_ref = X_ref.to(dev) + if X.device != torch.device(dev): X = X.to(dev) ratio = cr[li] if li < len(cr) else 128 kc = kv_caches[li] + X_in_mag = X.abs().max().item() # Production forward — capture intermediates attn_mhc = attn_mhcs.get(li) ffn_mhc = ffn_mhcs.get(li) - A_l_a, B_l_a, C_l_a = attn_mhc._dynamic_params(X_prod) + A_l_a, B_l_a, C_l_a = attn_mhc._dynamic_params(X) ctx_a = mHCContext(B_l=B_l_a, C_l=C_l_a) x_quant_attn = mhc_rmsnorm_quantize_nvfp4( - X_prod, A_l_a, attn_norms.get(li).to(dev, torch.float32)) + X, A_l_a, attn_norms.get(li).to(dev, torch.float32)) x_normed = dequantize_nvfp4(x_quant_attn.x_fp4, x_quant_attn.x_sf, x_quant_attn.gsa) F_attn, q_a = forward_attention( x_normed, 
layer_w[li], li, cfg, *rope_caches[gpu], kc, dec_pos, compressors.get(li), indexers.get(li), prod_lins.get(li), x_quant=x_quant_attn) - X_mid = attn_mhc.post_block(X_prod, F_attn, ctx_a) + X_mid = attn_mhc.post_block(X, F_attn, ctx_a) A_l_f, B_l_f, C_l_f = ffn_mhc._dynamic_params(X_mid) ctx_f = mHCContext(B_l=B_l_f, C_l=C_l_f) @@ -405,20 +329,9 @@ def main(): x_ffn = dequantize_nvfp4(x_quant_ffn.x_fp4, x_quant_ffn.x_sf, x_quant_ffn.gsa) F_ffn = moe_forward(x_ffn, li, moe_runners.get(li), se_runners.get(li), routers.get(li), dec_tid32.to(dev)) - X_prod_next = ffn_mhc.post_block(X_mid, F_ffn, ctx_f) + X_next = ffn_mhc.post_block(X_mid, F_ffn, ctx_f) - # Reference forward - X_ref_next = ref_forward_layer(X_ref, ref_layer_w[li], li, cfg, *ref_rope_caches[gpu], - ref_attn_mhcs.get(li), ref_ffn_mhcs.get(li), - ref_attn_norms.get(li).to(dev, torch.float32) if li in ref_attn_norms else None, - ref_ffn_norms.get(li).to(dev, torch.float32) if li in ref_ffn_norms else None, - ref_kv_caches[li], dec_pos.to(dev), 0, - ref_compressors.get(li), ref_indexers.get(li)) - - # Compare - cos_val = cosine(X_prod_next.to(DEVICE), X_ref_next.to(DEVICE)) - mag_prod = X_prod_next.to(DEVICE).abs().max().item() - mag_ref = X_ref_next.to(DEVICE).abs().max().item() + X_out_mag = X_next.to(DEVICE).abs().max().item() f_attn_mag = F_attn.to(DEVICE).abs().max().item() f_ffn_mag = F_ffn.to(DEVICE).abs().max().item() @@ -427,53 +340,40 @@ def main(): n_comp = kc.n_comp mode = "CSA" if ratio == 4 else ("HCA" if ratio > 4 else "SWA") + # Causality check future_leak = False - if ratio == 4 and n_comp > 0 and kc.comp_pos is not None and kc.comp_pos.numel() > 0: + if n_comp > 0 and kc.comp_pos is not None and kc.comp_pos.numel() > 0: visible_comp_pos = kc.comp_pos[:n_comp] future_leak = (visible_comp_pos >= decode_pos).any().item() - status = "PASS" if cos_val >= 0.99 else "FAIL" - if cos_val < 0.99: all_pass = False - - print(f" {li:>3} {ratio:>5} {cos_val:>12.6f} {mag_prod:>10.2f} {mag_ref:>10.2f} " + 
print(f" {li:>3} {ratio:>5} {X_in_mag:>12.2f} {X_out_mag:>12.2f} " f"{f_attn_mag:>10.2f} {f_ffn_mag:>10.2f} {n_comp:>6} {swa_len:>4} {mode:>8} " f"{'YES!' if future_leak else 'no':>5}") - if cos_val < 0.99: - print(f" FAIL detail: |X_prod_in|={X_prod.to(DEVICE).abs().max().item():.2f} " - f"|X_ref_in|={X_ref.to(DEVICE).abs().max().item():.2f}") - B_a = B_l_a - print(f" B_l row_sum=[{B_a.sum(-1).min().item():.4f},{B_a.sum(-1).max().item():.4f}] " - f"col_sum=[{B_a.sum(-2).min().item():.4f},{B_a.sum(-2).max().item():.4f}]") - print(f" A_l=[{A_l_a.min().item():.4f},{A_l_a.max().item():.4f}] " - f"C_l=[{C_l_a.min().item():.4f},{C_l_a.max().item():.4f}]") + # mHC diagnostics + B_a = B_l_a + print(f" mHC: B_l row_sum=[{B_a.sum(-1).min().item():.4f},{B_a.sum(-1).max().item():.4f}] " + f"col_sum=[{B_a.sum(-2).min().item():.4f},{B_a.sum(-2).max().item():.4f}] " + f"A=[{A_l_a.min().item():.4f},{A_l_a.max().item():.4f}] " + f"C=[{C_l_a.min().item():.4f},{C_l_a.max().item():.4f}]") - X_prod = X_prod_next - X_ref = X_ref_next + # CSA specifics + if ratio == 4 and n_comp > 0: + print(f" CSA: n_comp={n_comp} swa_len={swa_len} total_attend={n_comp + swa_len}") + + X = X_next # Summary print(f"\n{'='*70}") print("PART A SUMMARY") print(f"{'='*70}") - if all_pass: - print("ALL LAYERS PASS (cos >= 0.99) — production matches reference at decode") - print("The decode degeneration is likely caused by accumulated small errors across 61 layers,") - print("or by components beyond these first layers (e.g., lm_head, hc_head).") - else: - print("SOME LAYERS FAIL — production diverges from reference at decode") - print("The failing layer(s) contain the root cause of decode degeneration.") - print() - print("Next steps:") - print(" 1. For each failing layer, compare intermediate values:") - print(" - x_normed (after mHC pre + rmsnorm)") - print(" - F_attn (after attention)") - print(" - X_mid (after mHC post)") - print(" - F_ffn (after MoE)") - print(" 2. 
Check mHC B_l doubly-stochastic property") - print(" 3. Check compressed/SWA visible range parity") - print(" 4. Check indexer top-k indices validity") - - return 0 if all_pass else 1 + print("Production pipeline diagnostics complete.") + print("Check the |X| values above for:") + print(" 1. Exponential growth (mHC residual blowup)") + print(" 2. Sudden jumps (NVFP4 quantization error)") + print(" 3. NaN/Inf (numerical instability)") + print(" 4. future_leak=YES (causality violation in compressed KV)") + return 0 if __name__ == "__main__":