Fix mHCContext attribute access (not tuple unpacking) and enable attention diag
This commit is contained in:
@@ -370,8 +370,7 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
|
||||
# -- mHC pre_block (attention) --
|
||||
x_in, attn_ctx = attn_mhc.pre_block(X_l) # x_in: (T, H)
|
||||
if MHC_DIAG: # mHC diagnostics
|
||||
A_l = None
|
||||
B_l, C_l = attn_ctx
|
||||
B_l, C_l = attn_ctx.B_l, attn_ctx.C_l
|
||||
print(f" L{li} pre_attn: |X_l|={X_l.abs().max().item():.2f} |x_in|={x_in.abs().max().item():.2f}", flush=True)
|
||||
|
||||
# -- RMSNorm (pre-norm before attention) --
|
||||
@@ -461,7 +460,7 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
|
||||
q_input, k_expanded, v_expanded, scale=scale, is_causal=False)
|
||||
attn_out = attn_out.permute(1, 0, 2) # (T, n_h, hd)
|
||||
# Diagnostic: check attention entropy (how spread out the attention is)
|
||||
if False: # MHC_DIAG
|
||||
if MHC_DIAG and li < 3:
|
||||
with torch.no_grad():
|
||||
scores = torch.matmul(q_input, k_expanded.transpose(-1, -2)) * scale # (n_h, T, seq_len)
|
||||
weights = torch.softmax(scores.float(), dim=-1) # (n_h, 1, seq_len)
|
||||
@@ -520,7 +519,7 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
|
||||
X_mid = attn_mhc.post_block(X_l, F_attn, attn_ctx) # (T, n_hc, H)
|
||||
# Diagnostic: check mHC is stabilizing the residual
|
||||
if MHC_DIAG: # mHC diagnostics
|
||||
B_l, C_l = attn_ctx
|
||||
B_l, C_l = attn_ctx.B_l, attn_ctx.C_l
|
||||
print(f" L{li} attn: |X_l|={X_l.abs().max().item():.2f} |F_attn|={F_attn.abs().max().item():.2f} |B|={B_l.abs().max().item():.4f} |C|={C_l.abs().max().item():.4f} |X_mid|={X_mid.abs().max().item():.2f}")
|
||||
# Check B_l is doubly stochastic (rows sum to 1.0)
|
||||
B_row_sums = B_l.sum(dim=-1) # (T, n_hc)
|
||||
@@ -544,7 +543,7 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
|
||||
# -- mHC post_block (FFN) --
|
||||
X_next = ffn_mhc.post_block(X_mid, F_ffn, ffn_ctx) # (T, n_hc, H)
|
||||
if MHC_DIAG: # ffn mHC diagnostics
|
||||
B_l_ffn, C_l_ffn = ffn_ctx
|
||||
B_l_ffn, C_l_ffn = ffn_ctx.B_l, ffn_ctx.C_l
|
||||
print(f" L{li} ffn: |X_mid|={X_mid.abs().max().item():.2f} |F_ffn|={F_ffn.abs().max().item():.2f} |B|={B_l_ffn.abs().max().item():.4f} |C|={C_l_ffn.abs().max().item():.4f} |X_next|={X_next.abs().max().item():.2f}", flush=True)
|
||||
|
||||
return X_next
|
||||
|
||||
Reference in New Issue
Block a user