Add attention and Q/KV diagnostics (MHC_DIAG flag)
This commit is contained in:
@@ -409,6 +409,11 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
|
||||
q_heads = q.reshape(T, n_h, hd) # (T, n_h, hd)
|
||||
kv_new = kv.reshape(T, 1, hd) # (T, 1, hd) — 1 KV head
|
||||
|
||||
# Diagnostic: Q/KV norms
|
||||
if MHC_DIAG and li < 3:
|
||||
print(f" L{li} Q: |q|={q_heads.abs().max().item():.2f} mean={q_heads.float().abs().mean().item():.4f}")
|
||||
print(f" L{li} KV: |kv|={kv_new.abs().max().item():.2f} mean={kv_new.float().abs().mean().item():.4f}")
|
||||
|
||||
# -- Apply RoPE to Q (at current positions) --
|
||||
positions_dev = positions.to(device)
|
||||
q_heads = apply_rope_partial(q_heads, positions_dev, rope_cos, rope_sin, hd, rd)
|
||||
@@ -455,6 +460,16 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
|
||||
attn_out = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_input, k_expanded, v_expanded, scale=scale, is_causal=False)
|
||||
attn_out = attn_out.permute(1, 0, 2) # (T, n_h, hd)
|
||||
# Diagnostic: check attention entropy (how spread out the attention is)
|
||||
if False: # MHC_DIAG
|
||||
with torch.no_grad():
|
||||
scores = torch.matmul(q_input, k_expanded.transpose(-1, -2)) * scale # (n_h, T, seq_len)
|
||||
weights = torch.softmax(scores.float(), dim=-1) # (n_h, T, seq_len) — same shape as scores
|
||||
# For head 0: what positions get the most weight?
|
||||
w0 = weights[0, 0] # (seq_len,)
|
||||
top3_pos = torch.topk(w0, min(3, seq_len))
|
||||
entropy = -(w0 * (w0 + 1e-10).log()).sum().item()
|
||||
print(f" L{li} attn: seq_len={seq_len} entropy={entropy:.2f} top3_pos={top3_pos.indices.tolist()} top3_w={top3_pos.values.tolist()}")
|
||||
else:
|
||||
# Use FMHA kernel for longer sequences (padding effect is negligible)
|
||||
from dsv4.kernels.attention.fmha_multitile_op import fmha_multitile_decode_raw
|
||||
|
||||
Reference in New Issue
Block a user