From ce3d6069cc99f70a8d2b713b56dd2704fb435550 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 31 May 2026 10:07:14 +0000 Subject: [PATCH] CRITICAL FIX: mHC base/scale ordering matches fn ordering [pre, res, post] MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit All three mHC parameter tensors (fn, base, scale) share the same ordering as _dynamic_params' A/B/C split: [pre(4), res(16), post(4)]. Previous code loaded base as [pre(4), post(4), res(16)] and scale as [alpha_pre, alpha_post, alpha_res] — swapping S_res and S_post, and alpha_res and alpha_post. This caused the Sinkhorn-Knopp B_l matrix to be computed with wrong bias values, allowing the residual to explode. Also: added MHC_DIAG flag for per-layer diagnostics (B_l row/col sums, C_l values) to verify doubly-stochastic constraint is satisfied. --- single_shot_inference.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/single_shot_inference.py b/single_shot_inference.py index 859efa92..66bcfc12 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -61,6 +61,7 @@ PROMPT = _args.prompt or "The capital of France is" NUM_GPUS = 8 SKIP_ROUTED_MOE = _args.skip_moe # If True, only use shared expert (debug) INVERSE_ROPE = not _args.no_inverse_rope # If False, skip inverse RoPE on attention output (diagnostic) +MHC_DIAG = False # If True, print per-layer mHC diagnostics (B_l row/col sums, C_l values) # When True: applies inverse RoPE at query position → converts absolute→relative # When False: leaves relative position encoding intact for output projection # DSV4 partial RoPE only affects last 64/512 dims; first 448 are always un-RoPE'd @@ -153,15 +154,18 @@ class mHCBlock: W_res = fn[n:n+n*n].to(device=dev, dtype=torch.float32).contiguous() # fn[4:20] W_post = fn[n+n*n:].to(device=dev, dtype=torch.float32).contiguous() # fn[20:24] - # base: [S_pre(4), S_post(4), S_res(16)] + # base: [S_pre(4), S_res(16), S_post(4)] — matches fn ordering [A, B, C] + # The checkpoint stores all 3 arrays (fn, base, scale) in the same + # [pre, res, post] order matching _dynamic_params' A/B/C split. + # Previous note "[pre, post, res]" was incorrect for base/scale. S_pre = base[0:n].reshape(1, n).to(device=dev, dtype=torch.bfloat16).contiguous() - S_post = base[n:2*n].reshape(n, 1).to(device=dev, dtype=torch.bfloat16).contiguous() - S_res = base[2*n:].reshape(n, n).to(device=dev, dtype=torch.bfloat16).contiguous() + S_res = base[n:n+n*n].reshape(n, n).to(device=dev, dtype=torch.bfloat16).contiguous() # base[4:20] + S_post = base[n+n*n:].reshape(n, 1).to(device=dev, dtype=torch.bfloat16).contiguous() # base[20:24] - # scale: [alpha_pre, alpha_post, alpha_res] + # scale: [alpha_pre, alpha_res, alpha_post] — matches [A, B, C] ordering alpha_pre = scale[0].item() - alpha_post = scale[1].item() - alpha_res = scale[2].item() + alpha_res = scale[1].item() + alpha_post = scale[2].item() self._impl.load_weights( W_pre=W_pre, W_res=W_res, W_post=W_post, @@ -347,7 +351,7 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, # -- mHC pre_block (attention) -- x_in, attn_ctx = attn_mhc.pre_block(X_l) # x_in: (T, H) - if False: # diag disabled + if MHC_DIAG: # mHC diagnostics A_l = None B_l, C_l = attn_ctx print(f" L{li} pre_attn: |X_l|={X_l.abs().max().item():.2f} |x_in|={x_in.abs().max().item():.2f}", flush=True) @@ -482,9 +486,14 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, # -- mHC post_block (attention) -- X_mid = attn_mhc.post_block(X_l, F_attn, attn_ctx) # (T, n_hc, H) # Diagnostic: check mHC is stabilizing the residual - if False: # Disable diagnostics for production run + if MHC_DIAG: # mHC diagnostics B_l, C_l = attn_ctx - print(f" L{li} attn: |X_l|={X_l.abs().max().item():.2f} |F_attn|={F_attn.abs().max().item():.2f} |B|={B_l.abs().max().item():.4f} |C|={C_l.abs().max().item():.4f} |X_mid|={X_mid.abs().max().item():.2f}", flush=True) + print(f" L{li} attn: |X_l|={X_l.abs().max().item():.2f} |F_attn|={F_attn.abs().max().item():.2f} |B|={B_l.abs().max().item():.4f} |C|={C_l.abs().max().item():.4f} |X_mid|={X_mid.abs().max().item():.2f}") + # Check B_l is doubly stochastic (rows sum to 1.0) + B_row_sums = B_l.sum(dim=-1) # (T, n_hc) + B_col_sums = B_l.sum(dim=-2) # (T, n_hc) + print(f" B row_sums={B_row_sums[0].tolist()} col_sums={B_col_sums[0].tolist()}") + print(f" C_l={C_l[0].tolist()}") # ================================================================== # FFN SUB-BLOCK @@ -501,7 +510,7 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, # -- mHC post_block (FFN) -- X_next = ffn_mhc.post_block(X_mid, F_ffn, ffn_ctx) # (T, n_hc, H) - if False: # diag disabled + if MHC_DIAG: # ffn mHC diagnostics B_l_ffn, C_l_ffn = ffn_ctx print(f" L{li} ffn: |X_mid|={X_mid.abs().max().item():.2f} |F_ffn|={F_ffn.abs().max().item():.2f} |B|={B_l_ffn.abs().max().item():.4f} |C|={C_l_ffn.abs().max().item():.4f} |X_next|={X_next.abs().max().item():.2f}", flush=True)