diff --git a/tests/unit/test_degeneration_2_mhc_falsify.py b/tests/unit/test_degeneration_2_mhc_falsify.py index c7d7f918..d0ea8141 100644 --- a/tests/unit/test_degeneration_2_mhc_falsify.py +++ b/tests/unit/test_degeneration_2_mhc_falsify.py @@ -1,64 +1,60 @@ #!/usr/bin/env python3 """DEGENERATION TEST 2 — Falsify the mHC "root cause". -Claim under test: "|X|=860 compresses the logit range so the model can't distinguish tokens." +Claim: "|X|=860 compresses the logit range so the model can't distinguish tokens." +Test: RMSNorm is scale-invariant, so |X|=860 and |X|=8 should give the same logits. + If they differ, the final norm is missing/broken, NOT mHC. -Why it's suspect: there is a final RMSNorm before the LM head, and RMSNorm is -scale-invariant — it divides the magnitude out. So |X|=860 and |X|=8 should produce -the SAME logits (modulo the learned norm weight). Also, the residual grows just as -much during prefill yet prefill/first-token is correct — magnitude common to both -phases cannot be what breaks only decode. - -Procedure: -1. Confirm the final norm exists and is applied. -2. Falsification: compute logits with X as-is (|X|≈860) and X/100, compare. - If argmax matches and cos≈1.0 → mHC growth is EXONERATED. - If they differ → something downstream is magnitude-sensitive → norm is missing/broken. - -This test loads the FULL model (61 layers, 8 GPUs, production values). -It runs one decode step and captures the final-layer residual for the comparison. +This test runs single_shot_inference.py with a monkey-patch that intercepts +the final-layer residual and does the scale-invariance comparison. """ -import os, sys, time, json, math -import torch -import torch.nn.functional as F - -# This test imports and reuses single_shot_inference.py's infrastructure -# but intercepts at the hc_head / final_norm / lm_head stage. +import os, sys, time CHECKPOINT_DIR = os.environ.get("CHECKPOINT_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4") -NUM_GPUS = 8 def main(): - print("=" * 70) - print("DEGENERATION TEST 2 — Falsify mHC residual growth root cause") - print("=" * 70) + import torch + import torch.nn.functional as F + from transformers import AutoTokenizer - # We need to run the full pipeline to get the final-layer residual X. - # The simplest approach: import single_shot's main, but intercept at the lm_head. - # Instead of re-implementing everything, we'll modify the decode loop to capture X - # and do the comparison after one decode step. + # We'll import single_shot and monkey-patch the decode loop to capture X + # after all layers and before hc_head/final_norm/lm_head. + # Then we do the scale-invariance test on the captured X. - # Load the model using single_shot's infrastructure sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + # Load everything through single_shot's infrastructure + # Strategy: import single_shot, call its setup functions, then do our own decode + # with interception at the hc_head point. + + import json + from pathlib import Path from single_shot_inference import ( load_all_weights, build_rope_cache, rmsnorm, unweighted_rmsnorm, - mHCLayer, HcHead, KVCache, Compressor, Indexer, - make_nvfp4_linear, get_nvfp4_weight, do_nvfp4_linear_ref, - forward_layer, moe_forward, _cache_layer_weights_no_experts, + HcHead, KVCache, Compressor, Indexer, + make_nvfp4_linear, get_nvfp4_weight, + forward_layer, moe_forward, + _cache_layer_weights_no_experts, _load_moe_weights_stacked, _load_shared_expert_weights, FP4_LUT, HC_EPS, THINK_START, THINK_END, USER_TOKEN, ASSISTANT_TOKEN, kill_stale_gpu_processes, ) - - from transformers import AutoTokenizer - from dsv4.layers.mhc import mHCLayer as mHCLayerProd + from dsv4.layers.mhc import mHCLayer from dsv4.layers.router import Router - from dsv4.layers.moe import Nvfp4MoE - from dsv4.layers.shared_expert import Nvfp4SharedExpert from dsv4.layers.linear import Nvfp4Linear from dsv4.layers.grouped_linear import Nvfp4GroupedLinear + from dsv4.layers.moe import Nvfp4MoE + from dsv4.layers.shared_expert import Nvfp4SharedExpert from dsv4.ops.quantize import quantize_weight_to_nvfp4 + NUM_GPUS = 8 + PROMPT = "The capital of France is" + HIDDEN = 7168 + + print("=" * 70) + print("DEGENERATION TEST 2 — Falsify mHC residual growth root cause") + print("=" * 70) + t0 = time.time(); torch.manual_seed(42) with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f: @@ -67,19 +63,15 @@ def main(): hd = cfg["head_dim"]; n_h = cfg["num_attention_heads"] rd = cfg.get("qk_rope_head_dim", 64) cr = cfg.get("compress_ratios", [128] * n_layers) - PROMPT = "The capital of France is" - - print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}") + print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}") # Load weights print(f"\nLoading weights..."); all_w = load_all_weights(CHECKPOINT_DIR) - - # Build production components (same as single_shot main) kill_stale_gpu_processes() for g in range(NUM_GPUS): torch.cuda.set_device(g); torch.cuda.empty_cache() torch.cuda.set_device(0) - # mHC + norms + # Build mHC + norms attn_mhcs, ffn_mhcs, attn_norms, ffn_norms = {}, {}, {}, {} for li in range(n_layers): dev = f"cuda:{li % NUM_GPUS}" @@ -89,7 +81,7 @@ def main(): ]: fn, base, scale = all_w.get(fn_s), all_w.get(base_s), all_w.get(scale_s) if fn is not None and base is not None and scale is not None: - m = mHCLayerProd(hidden_dim=H, n_hc=4, t_max_sinkhorn=20, device=dev) + m = mHCLayer(hidden_dim=H, n_hc=4, t_max_sinkhorn=20, device=dev) n = 4 m.load_weights( W_pre=fn[0:n].to(dev, torch.float32), W_post=fn[n:2*n].to(dev, torch.float32), @@ -127,8 +119,7 @@ def main(): else: oa_bf = all_w.get(f"{pfx}.o_a_proj.weight") if oa_bf is not None: wo_a.set_bf16_weight(oa_bf.bfloat16().to(dev)) - pl['o_a'] = wo_a - wo_a._use_runtime_gsa = True + pl['o_a'] = wo_a; wo_a._use_runtime_gsa = True pl['o_b'] = make_nvfp4_linear(16384, 7168, dev, all_w, pfx, 'o_b_proj') prod_lins[li] = pl @@ -136,7 +127,7 @@ def main(): routers, moe_runners, se_runners = {}, {}, {} for li in range(n_layers): dev = f"cuda:{li % NUM_GPUS}"; pfx = f"model.layers.{li}.mlp" - torch.cuda.set_device(li % NUM_GPUS); torch.cuda.synchronize() + torch.cuda.set_device(li % NUM_GPUS) is_hash = (li < cfg.get("num_hash_layers", 3)) and (f"{pfx}.gate.tid2eid" in all_w) router = Router(hidden_size=H, num_experts=cfg["n_routed_experts"], top_k=cfg.get("num_experts_per_tok", 6), @@ -152,51 +143,39 @@ def main(): if gate_w is not None and gate_ws is not None: gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev) gate_w_view = gate_w.to(dev).view(torch.float4_e2m1fn_x2) if gate_w.dtype == torch.uint8 else gate_w.to(dev) - gate_lin.fp4 = [gate_w_view] - gate_lin.sf = [gate_ws.to(dev)] + gate_lin.fp4 = [gate_w_view]; gate_lin.sf = [gate_ws.to(dev)] ws2_v = gate_ws2.float().item() if gate_ws2 is not None else 1.0 isc_v = gate_isc.float().item() if gate_isc is not None else 1.0/(6.0*448.0) - gate_lin.gs = [1.0] - gate_lin.ws2 = [torch.tensor([ws2_v], device=dev, dtype=torch.float32)] - gate_lin._activation_global_scale = isc_v - gate_lin._use_runtime_gsa = True - gate_lin.finalize_weights() - router.load_nvfp4_gate(gate_lin) + gate_lin.gs = [1.0]; gate_lin.ws2 = [torch.tensor([ws2_v], device=dev, dtype=torch.float32)] + gate_lin._activation_global_scale = isc_v; gate_lin._use_runtime_gsa = True + gate_lin.finalize_weights(); router.load_nvfp4_gate(gate_lin) router.load_weights(e_bias=eb.to(dev, torch.float32)) else: gw = all_w.get(f"{pfx}.gate.weight") if gw is not None: g_bf16 = gw if gw.shape == (E, H) else gw.T.contiguous() g_bf16 = g_bf16.bfloat16().to(dev) - from dsv4.ops.quantize import quantize_to_nvfp4 g_fp4, g_sf, g_gs = quantize_to_nvfp4(g_bf16) gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev) gate_lin.fp4 = [g_fp4]; gate_lin.sf = [g_sf]; gate_lin.gs = [g_gs] gate_lin.ws2 = [torch.tensor([g_gs], device=dev, dtype=torch.float32)] - gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0) - gate_lin._use_runtime_gsa = True - gate_lin.finalize_weights() - router.load_nvfp4_gate(gate_lin) + gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0); gate_lin._use_runtime_gsa = True + gate_lin.finalize_weights(); router.load_nvfp4_gate(gate_lin) router.load_weights(e_bias=eb.to(dev, torch.float32)) router.finalize_weights(); routers[li] = router moe = Nvfp4MoE(num_experts=cfg["n_routed_experts"], hidden_size=H, intermediate_size=cfg.get("moe_intermediate_size", 3072), top_k=cfg.get("num_experts_per_tok", 6), device=dev) - moe.set_swiglu_limit(cfg.get("swiglu_limit", 10.0)) - moe.set_fused_swiglu(True) + moe.set_swiglu_limit(cfg.get("swiglu_limit", 10.0)); moe.set_fused_swiglu(True) _load_moe_weights_stacked(all_w, li, pfx, dev, moe, cfg) - moe._ensure_stacked() - moe._use_runtime_gsa = True - moe_runners[li] = moe + moe._ensure_stacked(); moe._use_runtime_gsa = True; moe_runners[li] = moe se = Nvfp4SharedExpert(hidden_size=H, intermediate_size=cfg.get("moe_intermediate_size", 3072), device=dev, swiglu_limit=cfg.get("swiglu_limit", 10.0)) se.set_fused_swiglu(True) _load_shared_expert_weights(all_w, li, pfx, dev, se, cfg) - se._ensure_initialized() - se._use_runtime_gsa = True - se_runners[li] = se + se._ensure_initialized(); se._use_runtime_gsa = True; se_runners[li] = se if (li+1) % 10 == 0: print(f" Built {li+1}/{n_layers} MoE layers") torch.cuda.empty_cache() @@ -204,30 +183,23 @@ def main(): torch.cuda.set_device(0) embed_w = all_w.get("model.embed_tokens.weight") embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to('cuda:0')) - lm_w_raw = all_w.get("lm_head.weight", embed_w).bfloat16().to('cuda:0') lm_head_lin = Nvfp4Linear(lm_w_raw.shape[1], lm_w_raw.shape[0], max_num_tokens=8192, device='cuda:0') lm_fp4, lm_sf, lm_gs = quantize_weight_to_nvfp4(lm_w_raw.T.contiguous()) lm_head_lin.fp4 = [lm_fp4.permute(1, 0).contiguous()] lm_head_lin.sf = [lm_sf.permute(1, 0).contiguous()] - lm_head_lin.gs = [lm_gs] - lm_head_lin.ws2 = [None] + lm_head_lin.gs = [lm_gs]; lm_head_lin.ws2 = [None] lm_head_lin._activation_global_scale = 1.0 / (6.0 * 448.0) - lm_head_lin._use_runtime_gsa = True - lm_head_lin.finalize_weights() + lm_head_lin._use_runtime_gsa = True; lm_head_lin.finalize_weights() final_norm_w = all_w.get("model.norm.weight") - if final_norm_w is not None: - final_norm_w = final_norm_w.to('cuda:0', torch.float32) + if final_norm_w is not None: final_norm_w = final_norm_w.to('cuda:0', torch.float32) hc_head = HcHead(H, 4, 'cuda:0') - hc_fn = all_w.get("model.hc_head.hc_fn") - hc_base = all_w.get("model.hc_head.hc_base") + hc_fn = all_w.get("model.hc_head.hc_fn"); hc_base = all_w.get("model.hc_head.hc_base") hc_scale = all_w.get("model.hc_head.hc_scale") - if hc_fn is not None and hc_base is not None: - hc_head.load(hc_fn, hc_base, hc_scale) + if hc_fn is not None and hc_base is not None: hc_head.load(hc_fn, hc_base, hc_scale) - # RoPE rp = cfg.get("rope_scaling", cfg.get("rope_parameters", {})) rt = rp.get("type", rp.get("rope_type", "yarn")); rf = rp.get("factor", 16.0) rtheta = cfg.get("rope_theta", 10000.) @@ -236,7 +208,6 @@ def main(): rope_caches = {g: build_rope_cache(romax, rd, f"cuda:{g}", rtheta, rt, rf, romax, rbfast, rbslow) for g in range(NUM_GPUS)} - # KV caches, compressors, indexers kv_caches, compressors, indexers = {}, {}, {} n_ih = cfg.get("index_n_heads", 64); ihd = cfg.get("index_head_dim", 128) itk = cfg.get("index_topk", 1024) @@ -248,7 +219,6 @@ def main(): if ratio > 0: compressors[li] = Compressor(ratio, hd, H, dev) if ratio == 4: indexers[li] = Indexer(n_ih, ihd, itk, dev) - # Cache layer weights devs = [f"cuda:{g}" for g in range(NUM_GPUS)] layer_w = _cache_layer_weights_no_experts(all_w, n_layers, devs) del all_w; import gc; gc.collect() @@ -262,15 +232,13 @@ def main(): tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR) bos = tokenizer.bos_token_id or 0 + # FIXED: no \n\n (official DSV4 encoding spec) input_ids = [bos, USER_TOKEN] - input_ids += tokenizer.encode('\n\n' + PROMPT, add_special_tokens=False) + input_ids += tokenizer.encode(PROMPT, add_special_tokens=False) input_ids.append(ASSISTANT_TOKEN) input_ids.append(THINK_START) - print(f"\nPhase: Prefill + 1 decode step") - print(f" Input: {len(input_ids)} tokens") - - # Prefill + print(f"\nPrefill + 1 decode step...") PREFILL_CHUNK = 128 n_prefill = len(input_ids) prefill_ids = torch.tensor(input_ids, dtype=torch.long, device='cuda:0') @@ -281,12 +249,10 @@ def main(): X = None for ci, cs in enumerate(chunk_starts): ce = min(cs + PREFILL_CHUNK, n_prefill) - chunk_len = ce - cs chunk_ids = prefill_ids[cs:ce] chunk_ids32 = prefill_ids32[cs:ce] chunk_positions = all_positions[cs:ce] - chunk_embed = embed(chunk_ids) - X = mHCLayerProd.init_state(chunk_embed) + X = mHCLayer.init_state(embed(chunk_ids)) for li in range(n_layers): gpu = li % NUM_GPUS if X.device != torch.device(f"cuda:{gpu}"): X = X.to(f"cuda:{gpu}") @@ -299,14 +265,14 @@ def main(): moe_runners.get(li), se_runners.get(li), routers.get(li), prod_lin=prod_lins.get(li)) X = X.to('cuda:0'); torch.cuda.set_device(0) - print(f" Chunk {ci+1}/{len(chunk_starts)}: OK", flush=True) + print(f" Chunk {ci+1}/{len(chunk_starts)}: OK |X|={X.abs().max().item():.1f}", flush=True) # Decode step 1 dec_tid = torch.tensor([input_ids[-1]], dtype=torch.long, device='cuda:0') dec_tid32 = dec_tid.to(torch.int32) dec_pos = torch.tensor([n_prefill - 1], dtype=torch.long, device='cuda:0') - X = mHCLayerProd.init_state(embed(dec_tid)) + X = mHCLayer.init_state(embed(dec_tid)) for li in range(n_layers): gpu = li % NUM_GPUS if X.device != torch.device(f"cuda:{gpu}"): X = X.to(f"cuda:{gpu}") @@ -328,76 +294,64 @@ def main(): print("TEST 2 — Falsify mHC residual growth root cause") print(f"{'='*70}") - # Step 1: Confirm final norm exists and is applied + # Step 1: Confirm final norm exists print(f"\n1. FINAL NORM CHECK:") print(f" final_norm_w exists: {final_norm_w is not None}") if final_norm_w is not None: print(f" final_norm_w shape: {final_norm_w.shape}, dtype: {final_norm_w.dtype}") print(f" final_norm_w range: [{final_norm_w.min().item():.6f}, {final_norm_w.max().item():.6f}]") else: - print(f" *** CRITICAL: final_norm_w is MISSING! This is likely the real bug! ***") + print(f" *** CRITICAL: final_norm_w is MISSING! ***") - # Step 2: Trace the full path: X → hc_head → final_norm → lm_head → logits + # Step 2: Residual inspection print(f"\n2. RESIDUAL INSPECTION:") - X_abs_max = X.abs().max().item() - print(f" |X| (final layer residual) = {X_abs_max:.4f}") + X_max = X.abs().max().item() + print(f" |X| (final layer residual) = {X_max:.4f}") print(f" X shape: {X.shape}, dtype: {X.dtype}") - # hc_head: takes (T, n_hc, d) → (T, d) + # Step 3: Trace full path X → hc_head → final_norm → lm_head → logits x_out = hc_head.forward(X) if hc_head is not None else X[:, 0, :] - x_out_max = x_out.abs().max().item() - print(f" |x_out| (after hc_head) = {x_out_max:.4f}") - print(f" x_out shape: {x_out.shape}, dtype: {x_out.dtype}") + print(f" |x_out| (after hc_head) = {x_out.abs().max().item():.4f}") - # Apply final norm if final_norm_w is not None: x_normed = rmsnorm(x_out, final_norm_w) - x_normed_max = x_normed.abs().max().item() - print(f" |x_normed| (after final_norm) = {x_normed_max:.4f}") - # Verify scale invariance: rmsnorm should divide out magnitude + print(f" |x_normed| (after final_norm) = {x_normed.abs().max().item():.4f}") + # Verify scale invariance of RMSNorm alone x_out_tiny = x_out / 100.0 x_normed_tiny = rmsnorm(x_out_tiny, final_norm_w) cos_norm = F.cosine_similarity(x_normed.flatten().float(), x_normed_tiny.flatten().float(), dim=0).item() print(f" RMSNorm scale invariance: cos(x_normed, x_normed_tiny) = {cos_norm:.8f}") - print(f" (Expected: 1.0 — RMSNorm is scale-invariant)") else: x_normed = x_out - print(f" *** NO FINAL NORM APPLIED — logits will be magnitude-dependent! ***") + print(f" *** NO FINAL NORM — logits will be magnitude-dependent! ***") - # Step 3: Falsification — compute logits with X and X/100 - print(f"\n3. FALSIFICATION: logits with |X|={X_abs_max:.1f} vs |X/100|={X_abs_max/100:.1f}") + # Step 4: FALSIFICATION — logits with X vs X/100 + print(f"\n3. FALSIFICATION: logits with |X|={X_max:.1f} vs |X/100|={X_max/100:.2f}") - # Path A: logits with X as-is + # Path A: X as-is x_out_A = hc_head.forward(X) if hc_head is not None else X[:, 0, :] - if final_norm_w is not None: - x_out_A = rmsnorm(x_out_A, final_norm_w) + if final_norm_w is not None: x_out_A = rmsnorm(x_out_A, final_norm_w) logits_A = lm_head_lin(x_out_A) - # Path B: logits with X scaled down by 100 + # Path B: X scaled down by 100 X_scaled = X / 100.0 x_out_B = hc_head.forward(X_scaled) if hc_head is not None else X_scaled[:, 0, :] - if final_norm_w is not None: - x_out_B = rmsnorm(x_out_B, final_norm_w) + if final_norm_w is not None: x_out_B = rmsnorm(x_out_B, final_norm_w) logits_B = lm_head_lin(x_out_B) torch.cuda.synchronize() - - logits_A_f = logits_A.float() - logits_B_f = logits_B.float() - - argmax_A = logits_A_f.argmax().item() - argmax_B = logits_B_f.argmax().item() + logits_A_f = logits_A.float(); logits_B_f = logits_B.float() + argmax_A = logits_A_f.argmax().item(); argmax_B = logits_B_f.argmax().item() cos_AB = F.cosine_similarity(logits_A_f.flatten(), logits_B_f.flatten(), dim=0).item() - top5_A_vals, top5_A_ids = logits_A_f.topk(5) top5_B_vals, top5_B_ids = logits_B_f.topk(5) - print(f"\n logits_A (|X|={X_abs_max:.1f}):") + print(f"\n logits_A (|X|={X_max:.1f}):") print(f" range: [{logits_A_f.min().item():.2f}, {logits_A_f.max().item():.2f}]") print(f" argmax: {argmax_A} ('{tokenizer.decode([argmax_A])}')") print(f" top-5: {[(tokenizer.decode([tid.item()]), f'{val.item():.2f}') for tid, val in zip(top5_A_ids, top5_A_vals)]}") - print(f"\n logits_B (|X/100|={X_abs_max/100:.2f}):") + print(f"\n logits_B (|X/100|={X_max/100:.2f}):") print(f" range: [{logits_B_f.min().item():.2f}, {logits_B_f.max().item():.2f}]") print(f" argmax: {argmax_B} ('{tokenizer.decode([argmax_B])}')") print(f" top-5: {[(tokenizer.decode([tid.item()]), f'{val.item():.2f}') for tid, val in zip(top5_B_ids, top5_B_vals)]}") @@ -405,45 +359,36 @@ def main(): print(f"\n cos(logits_A, logits_B) = {cos_AB:.8f}") print(f" argmax_A == argmax_B: {argmax_A == argmax_B}") - # Step 4: Also check hc_head behavior — is it magnitude-sensitive? + # Step 5: hc_head magnitude sensitivity print(f"\n4. HC_HEAD MAGNITUDE SENSITIVITY:") - # hc_head does: rmsnorm(X) → linear → sigmoid → sum * X - # The sigmoid step makes it potentially magnitude-sensitive - # Let's check: does hc_head(X) scale linearly with |X|? x_out_A_raw = hc_head.forward(X) if hc_head is not None else X[:, 0, :] x_out_B_raw = hc_head.forward(X / 100.0) if hc_head is not None else (X / 100.0)[:, 0, :] cos_hc = F.cosine_similarity(x_out_A_raw.flatten().float(), (x_out_B_raw * 100.0).flatten().float(), dim=0).item() print(f" cos(hc_head(X), hc_head(X/100)*100) = {cos_hc:.8f}") - print(f" (If hc_head is NOT magnitude-sensitive, this should be 1.0)") print(f" |hc_head(X)| = {x_out_A_raw.abs().max().item():.4f}") print(f" |hc_head(X/100)| = {x_out_B_raw.abs().max().item():.6f}") print(f" |hc_head(X/100)*100| = {(x_out_B_raw * 100.0).abs().max().item():.4f}") - # Step 5: Final verdict + # Step 6: Verdict print(f"\n{'='*70}") print("VERDICT:") print(f"{'='*70}") if final_norm_w is None: print(" *** CRITICAL: FINAL NORM IS MISSING! ***") print(" The model has no RMSNorm before the LM head.") - print(" This means logits are magnitude-dependent → mHC residual growth IS the problem.") print(" FIX: Apply the final norm before lm_head.") elif cos_AB >= 0.999: print(" mHC residual growth is EXONERATED.") print(f" cos(logits_A, logits_B) = {cos_AB:.8f} ≈ 1.0") print(f" argmax_A={argmax_A}, argmax_B={argmax_B}") print(" |X| magnitude does NOT affect logits (RMSNorm divides it out).") - print(" The degeneration cause is elsewhere — likely Test 1 (chat template).") + print(" The degeneration cause is elsewhere — likely the prompt format (Test 1).") elif argmax_A != argmax_B: print(" mHC residual growth IS magnitude-sensitive despite final norm.") - print(f" argmax_A={argmax_A} ≠ argmax_B={argmax_B}") - print(f" cos = {cos_AB:.8f}") - print(" Something downstream of the residual is magnitude-sensitive.") - print(" Check: hc_head linearity, lm_head quantization, or the final norm is misapplied.") + print(f" argmax_A={argmax_A} ≠ argmax_B={argmax_B}, cos={cos_AB:.8f}") + print(" Something downstream is magnitude-sensitive.") else: print(f" Inconclusive: argmax matches but cos={cos_AB:.8f} < 0.999") - print(" Logits are similar but not identical at different |X| scales.") - print(" The magnitude does have SOME effect, but may not be the primary cause.") print(f"{'='*70}") if __name__ == "__main__":