CRITICAL FIX: mHC fn weight row ordering was wrong

fn rows are [W_pre(4), W_res(16), W_post(4)], matching [A_raw, B_raw, C_raw]
in _dynamic_params. The loader sliced them as [W_pre(4), W_post(4), W_res(16)],
so W_post received the first 4 rows of W_res, and W_res received the last
12 rows of W_res followed by the 4 W_post rows. The Sinkhorn-Knopp B_l matrix
was therefore computed from the wrong weights, allowing the residual to
explode (|X| 0.8 → 160K across 61 layers).
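
For readers unfamiliar with the B_l step, here is a minimal sketch of generic Sinkhorn-Knopp normalization in plain Python. It is not this repository's implementation: the function name `sinkhorn_knopp`, the use of `exp` to make entries positive, the iteration count, and the example values are all assumptions for illustration. It treats the 16 B_raw values as a 4x4 matrix, which is an inference from the row count above.

```python
import math


def sinkhorn_knopp(raw, iters=100):
    """Alternately normalize rows and columns of exp(raw).

    Hypothetical helper: the result approaches a doubly stochastic
    matrix (every row and every column sums to 1).
    """
    # Assumption: exponentiate so that every entry is strictly positive.
    m = [[math.exp(v) for v in row] for row in raw]
    for _ in range(iters):
        # Row step: make each row sum to 1.
        m = [[v / sum(row) for v in row] for row in m]
        # Column step: make each column sum to 1.
        col_sums = [sum(col) for col in zip(*m)]
        m = [[v / col_sums[j] for j, v in enumerate(row)] for row in m]
    return m


# Made-up 4x4 input standing in for a reshaped B_raw.
raw = [
    [0.5, -1.0, 0.2, 0.0],
    [1.5, 0.3, -0.7, 0.1],
    [-0.2, 0.8, 0.4, -1.1],
    [0.0, -0.5, 1.2, 0.6],
]
b = sinkhorn_knopp(raw)
print([round(sum(row), 6) for row in b])        # row sums, each close to 1.0
print([round(sum(col), 6) for col in zip(*b)])  # column sums, each close to 1.0
```

The sketch only shows what the normalization does to its input. Which weights produce that input is exactly what the row ordering decides.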

Correct: fn[0:4]=W_pre, fn[4:20]=W_res, fn[20:24]=W_post
Wrong:   fn[0:4]=W_pre, fn[4:8]=W_post, fn[8:24]=W_res
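
The two layouts above can be checked with a small stand-in, sketched here under stated assumptions: `fn` is replaced by a list of row labels rather than a real tensor, and `split_correct` / `split_wrong` are hypothetical helper names, not functions from the repository. List slicing selects rows the same way the first-dimension slices in the diff do.

```python
n = 4  # n_hc in the diff

# Hypothetical stand-in for fn: one label per row, in the true on-disk order
# [W_pre(n), W_res(n*n), W_post(n)].
fn = (
    [f"W_pre[{i}]" for i in range(n)]
    + [f"W_res[{i}]" for i in range(n * n)]
    + [f"W_post[{i}]" for i in range(n)]
)
assert len(fn) == 2 * n + n * n  # 24 rows


def split_correct(fn, n):
    """Slice rows as [W_pre(n), W_res(n*n), W_post(n)]."""
    return fn[0:n], fn[n:n + n * n], fn[n + n * n:]


def split_wrong(fn, n):
    """Slice rows as [W_pre(n), W_post(n), W_res(n*n)], the old order."""
    w_pre = fn[0:n]
    w_post = fn[n:2 * n]
    w_res = fn[2 * n:]
    return w_pre, w_res, w_post


good_pre, good_res, good_post = split_correct(fn, n)
bad_pre, bad_res, bad_post = split_wrong(fn, n)

print(bad_post)      # ['W_res[0]', 'W_res[1]', 'W_res[2]', 'W_res[3]']
print(bad_res[-n:])  # ['W_post[0]', 'W_post[1]', 'W_post[2]', 'W_post[3]']
```

Because both layouts total 24 rows, the old slicing raised no shape error, so the mismatch only showed up numerically.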
2026-05-31 10:02:57 +00:00
parent 0346e479d4
commit 9a43e9aa77


@@ -145,10 +145,13 @@ class mHCBlock:
n = self.n_hc
dev = self.device
-# fn rows: [W_pre(4), W_post(4), W_res(16)]
-W_pre = fn[0:n].to(device=dev, dtype=torch.float32).contiguous()
-W_post = fn[n:2*n].to(device=dev, dtype=torch.float32).contiguous()
-W_res = fn[2*n:].to(device=dev, dtype=torch.float32).contiguous()
+# fn rows: [W_pre(4), W_res(16), W_post(4)] — matches _dynamic_params
+# A_raw = proj[:, 0:4] ← W_pre
+# B_raw = proj[:, 4:20] ← W_res
+# C_raw = proj[:, 20:24] ← W_post
+W_pre = fn[0:n].to(device=dev, dtype=torch.float32).contiguous() # fn[0:4]
+W_res = fn[n:n+n*n].to(device=dev, dtype=torch.float32).contiguous() # fn[4:20]
+W_post = fn[n+n*n:].to(device=dev, dtype=torch.float32).contiguous() # fn[20:24]
# base: [S_pre(4), S_post(4), S_res(16)]
S_pre = base[0:n].reshape(1, n).to(device=dev, dtype=torch.bfloat16).contiguous()