From dd50c355a69237d600ffb15a7cf527b3c4954ccd Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 31 May 2026 11:37:32 +0000 Subject: [PATCH] Fix MHC_DIAG null check when SKIP_MHC is enabled --- single_shot_inference.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/single_shot_inference.py b/single_shot_inference.py index a1a67e20..742495aa 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -406,7 +406,7 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, else: # -- mHC pre_block (attention) -- x_in, attn_ctx = attn_mhc.pre_block(X_l) # x_in: (T, H) - if MHC_DIAG: # mHC diagnostics + if MHC_DIAG and attn_ctx is not None: # mHC diagnostics B_l, C_l = attn_ctx.B_l, attn_ctx.C_l print(f" L{li} pre_attn: |X_l|={X_l.abs().max().item():.2f} |x_in|={x_in.abs().max().item():.2f}", flush=True) @@ -558,7 +558,7 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, # -- mHC post_block (attention) -- X_mid = attn_mhc.post_block(X_l, F_attn, attn_ctx) # (T, n_hc, H) # Diagnostic: check mHC is stabilizing the residual - if MHC_DIAG: # mHC diagnostics + if MHC_DIAG and attn_ctx is not None: # mHC diagnostics B_l, C_l = attn_ctx.B_l, attn_ctx.C_l print(f" L{li} attn: |X_l|={X_l.abs().max().item():.2f} |F_attn|={F_attn.abs().max().item():.2f} |B|={B_l.abs().max().item():.4f} |C|={C_l.abs().max().item():.4f} |X_mid|={X_mid.abs().max().item():.2f}") # Check B_l is doubly stochastic (rows sum to 1.0) @@ -589,7 +589,7 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, else: # -- mHC post_block (FFN) -- X_next = ffn_mhc.post_block(X_mid, F_ffn, ffn_ctx) # (T, n_hc, H) - if MHC_DIAG: # ffn mHC diagnostics + if MHC_DIAG and ffn_ctx is not None: # ffn mHC diagnostics B_l_ffn, C_l_ffn = ffn_ctx.B_l, ffn_ctx.C_l print(f" L{li} ffn: |X_mid|={X_mid.abs().max().item():.2f} |F_ffn|={F_ffn.abs().max().item():.2f} |B|={B_l_ffn.abs().max().item():.4f} |C|={C_l_ffn.abs().max().item():.4f} |X_next|={X_next.abs().max().item():.2f}", flush=True)