CRITICAL FIX: mHC fn weight row ordering was wrong
fn rows are [W_pre(4), W_res(16), W_post(4)], matching [A_raw, B_raw, C_raw] in _dynamic_params. The loader was reading them as [W_pre(4), W_post(4), W_res(16)], which shifted the W_res rows by 4 and loaded the first 4 W_res rows as W_post. As a result, the Sinkhorn-Knopp B_l matrix was computed from the wrong weights, allowing the residual to explode (|X| 0.8 → 160K across 61 layers).

Correct: fn[0:4]=W_pre, fn[4:20]=W_res, fn[20:24]=W_post
Wrong:   fn[0:4]=W_pre, fn[4:8]=W_post, fn[8:24]=W_res
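The effect of the two slicings can be checked on a toy tensor. The sketch below is illustrative only: it uses plain Python lists of labelled rows in place of the real `fn` tensor, and `split_fn_rows` is a hypothetical helper name, not a function from this repository.

```python
# Hypothetical helper (not in the repository): split fn rows laid out as
# [W_pre(n), W_res(n*n), W_post(n)], the order stated in the commit message.
def split_fn_rows(fn, n):
    expected = n + n * n + n
    if len(fn) != expected:
        raise ValueError(f"expected {expected} rows, got {len(fn)}")
    w_pre = fn[0:n]              # fn[0:4]   when n == 4
    w_res = fn[n:n + n * n]      # fn[4:20]  when n == 4
    w_post = fn[n + n * n:]      # fn[20:24] when n == 4
    return w_pre, w_res, w_post


n = 4
# Label every row with the block it really belongs to.
fn = (
    [f"pre{i}" for i in range(n)]
    + [f"res{i}" for i in range(n * n)]
    + [f"post{i}" for i in range(n)]
)

w_pre, w_res, w_post = split_fn_rows(fn, n)

# The old slicing, for comparison.
old_post = fn[n:2 * n]   # actually the first 4 W_res rows
old_res = fn[2 * n:]     # W_res rows 4..15 followed by the 4 W_post rows

print("old W_post rows:", old_post)
print("old W_res tail :", old_res[-n:])
```

Both slicings yield 4 and 16 rows, so a shape check alone would not have caught the bug; only the row contents differ.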
@@ -145,10 +145,13 @@ class mHCBlock:
         n = self.n_hc
         dev = self.device
 
-        # fn rows: [W_pre(4), W_post(4), W_res(16)]
-        W_pre = fn[0:n].to(device=dev, dtype=torch.float32).contiguous()
-        W_post = fn[n:2*n].to(device=dev, dtype=torch.float32).contiguous()
-        W_res = fn[2*n:].to(device=dev, dtype=torch.float32).contiguous()
+        # fn rows: [W_pre(4), W_res(16), W_post(4)] — matches _dynamic_params
+        # A_raw = proj[:, 0:4] ← W_pre
+        # B_raw = proj[:, 4:20] ← W_res
+        # C_raw = proj[:, 20:24] ← W_post
+        W_pre = fn[0:n].to(device=dev, dtype=torch.float32).contiguous() # fn[0:4]
+        W_res = fn[n:n+n*n].to(device=dev, dtype=torch.float32).contiguous() # fn[4:20]
+        W_post = fn[n+n*n:].to(device=dev, dtype=torch.float32).contiguous() # fn[20:24]
 
         # base: [S_pre(4), S_post(4), S_res(16)]
         S_pre = base[0:n].reshape(1, n).to(device=dev, dtype=torch.bfloat16).contiguous()
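For context on the Sinkhorn-Knopp step mentioned in the commit message, here is a minimal, generic sketch of the algorithm in pure Python. It is not the repository's implementation, which this diff does not show. The function name `sinkhorn_knopp`, the iteration count, and the assumption that the 16 `B_raw` values are reshaped into a 4×4 matrix before normalization are all illustrative.

```python
import math


def sinkhorn_knopp(raw, iters=100):
    """Generic Sinkhorn-Knopp: push a square matrix toward doubly stochastic."""
    # Exponentiate so every entry is strictly positive (shift by the max
    # for numerical stability).
    m = max(max(row) for row in raw)
    mat = [[math.exp(v - m) for v in row] for row in raw]
    for _ in range(iters):
        # Normalize each row to sum to 1.
        mat = [[v / sum(row) for v in row] for row in mat]
        # Normalize each column to sum to 1.
        col_sums = [sum(col) for col in zip(*mat)]
        mat = [[v / col_sums[j] for j, v in enumerate(row)] for row in mat]
    return mat


# Assumed shape: 16 raw values viewed as a 4x4 matrix (made-up numbers).
raw = [
    [0.5, -1.0, 2.0, 0.0],
    [1.5, 0.3, -0.7, 0.9],
    [-0.2, 0.8, 0.1, -1.3],
    [0.0, 1.1, -0.4, 0.6],
]
B = sinkhorn_knopp(raw)
row_sums = [sum(row) for row in B]
col_sums = [sum(col) for col in zip(*B)]

# Each output entry is a weighted average of the inputs, so mixing with B
# cannot raise the largest magnitude in x.
x = [0.8, -0.3, 0.5, 0.1]
y = [sum(b * xi for b, xi in zip(row, x)) for row in B]

print("row sums:", [round(s, 6) for s in row_sums])
print("col sums:", [round(s, 6) for s in col_sums])
print("max |x| =", max(abs(v) for v in x), " max |y| =", max(abs(v) for v in y))
```

The sketch shows only the normalization itself. Which raw values feed it depends entirely on the row slicing fixed in this commit.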