From 9a43e9aa7747f7ac3fb52e782dc09d4b5ceed3bc Mon Sep 17 00:00:00 2001
From: biondizzle
Date: Sun, 31 May 2026 10:02:57 +0000
Subject: [PATCH] CRITICAL FIX: mHC fn weight row ordering was wrong
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

fn rows are [W_pre(4), W_res(16), W_post(4)] matching [A_raw, B_raw, C_raw]
in _dynamic_params. The loader was reading them as
[W_pre(4), W_post(4), W_res(16)], which shifted the W_res rows by 4 and
loaded the wrong rows as W_post.

This caused the Sinkhorn-Knopp B_l matrix to be computed from the wrong
weights, allowing the residual to explode (|X| 0.8 → 160K across 61 layers).

Correct: fn[0:4]=W_pre, fn[4:20]=W_res, fn[20:24]=W_post
Wrong: fn[0:4]=W_pre, fn[4:8]=W_post, fn[8:24]=W_res
---
 single_shot_inference.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/single_shot_inference.py b/single_shot_inference.py
index e1269360..859efa92 100644
--- a/single_shot_inference.py
+++ b/single_shot_inference.py
@@ -145,10 +145,13 @@ class mHCBlock:
         n = self.n_hc
         dev = self.device
 
-        # fn rows: [W_pre(4), W_post(4), W_res(16)]
-        W_pre = fn[0:n].to(device=dev, dtype=torch.float32).contiguous()
-        W_post = fn[n:2*n].to(device=dev, dtype=torch.float32).contiguous()
-        W_res = fn[2*n:].to(device=dev, dtype=torch.float32).contiguous()
+        # fn rows: [W_pre(4), W_res(16), W_post(4)] — matches _dynamic_params
+        # A_raw = proj[:, 0:4] ← W_pre
+        # B_raw = proj[:, 4:20] ← W_res
+        # C_raw = proj[:, 20:24] ← W_post
+        W_pre = fn[0:n].to(device=dev, dtype=torch.float32).contiguous() # fn[0:4]
+        W_res = fn[n:n+n*n].to(device=dev, dtype=torch.float32).contiguous() # fn[4:20]
+        W_post = fn[n+n*n:].to(device=dev, dtype=torch.float32).contiguous() # fn[20:24]
 
         # base: [S_pre(4), S_post(4), S_res(16)]
         S_pre = base[0:n].reshape(1, n).to(device=dev, dtype=torch.bfloat16).contiguous()