disable diagnostics for clean production run

This commit is contained in:
2026-05-31 03:32:17 +00:00
parent 3f9b441428
commit 171a9e0d10

View File

@@ -428,7 +428,7 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
# -- mHC pre_block (attention) --
x_in, attn_ctx = attn_mhc.pre_block(X_l) # x_in: (T, H)
if li < 3 or li >= 58:
if False: # diag disabled
A_l = None
B_l, C_l = attn_ctx
print(f" L{li} pre_attn: |X_l|={X_l.abs().max().item():.2f} |x_in|={x_in.abs().max().item():.2f}", flush=True)
@@ -513,7 +513,7 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
# -- mHC post_block (attention) --
X_mid = attn_mhc.post_block(X_l, F_attn, attn_ctx) # (T, n_hc, H)
# Diagnostic: check mHC is stabilizing the residual
if li < 3 or li >= 58:
if False: # Disable diagnostics for production run
B_l, C_l = attn_ctx
print(f" L{li} attn: |X_l|={X_l.abs().max().item():.2f} |F_attn|={F_attn.abs().max().item():.2f} |B|={B_l.abs().max().item():.4f} |C|={C_l.abs().max().item():.4f} |X_mid|={X_mid.abs().max().item():.2f}", flush=True)
@@ -532,7 +532,7 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
# -- mHC post_block (FFN) --
X_next = ffn_mhc.post_block(X_mid, F_ffn, ffn_ctx) # (T, n_hc, H)
if li < 3 or li >= 58:
if False: # diag disabled
B_l_ffn, C_l_ffn = ffn_ctx
print(f" L{li} ffn: |X_mid|={X_mid.abs().max().item():.2f} |F_ffn|={F_ffn.abs().max().item():.2f} |B|={B_l_ffn.abs().max().item():.4f} |C|={C_l_ffn.abs().max().item():.4f} |X_next|={X_next.abs().max().item():.2f}", flush=True)