diff --git a/single_shot_inference.py b/single_shot_inference.py index e1269360..859efa92 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -145,10 +145,13 @@ class mHCBlock: n = self.n_hc dev = self.device - # fn rows: [W_pre(4), W_post(4), W_res(16)] - W_pre = fn[0:n].to(device=dev, dtype=torch.float32).contiguous() - W_post = fn[n:2*n].to(device=dev, dtype=torch.float32).contiguous() - W_res = fn[2*n:].to(device=dev, dtype=torch.float32).contiguous() + # fn rows: [W_pre(4), W_res(16), W_post(4)] — matches _dynamic_params + # A_raw = proj[:, 0:4] ← W_pre + # B_raw = proj[:, 4:20] ← W_res + # C_raw = proj[:, 20:24] ← W_post + W_pre = fn[0:n].to(device=dev, dtype=torch.float32).contiguous() # fn[0:4] + W_res = fn[n:n+n*n].to(device=dev, dtype=torch.float32).contiguous() # fn[4:20] + W_post = fn[n+n*n:].to(device=dev, dtype=torch.float32).contiguous() # fn[20:24] # base: [S_pre(4), S_post(4), S_res(16)] S_pre = base[0:n].reshape(1, n).to(device=dev, dtype=torch.bfloat16).contiguous()