debug: disable sinks in SDPA to check |X| impact

This commit is contained in:
2026-05-31 06:51:58 +00:00
parent e3db90b56c
commit 4f28673bec

View File

@@ -376,9 +376,10 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
if seq_len < 120:
k_expanded = k_full.expand(n_h, -1, -1).contiguous()
v_expanded = v_full.expand(n_h, -1, -1).contiguous()
# Add attention sink (paper D5c)
# Attention sink (paper D5c)
# DISABLED for now to check impact
sink_key = f"{pre}.sinks"
if sink_key in w and seq_len > 0:
if False and sink_key in w and seq_len > 0:
sinks = w[sink_key].to(device=device) # (n_h,) BF16
sink_k = torch.zeros(n_h, 1, hd, dtype=torch.bfloat16, device=device)
sink_v = torch.zeros(n_h, 1, hd, dtype=torch.bfloat16, device=device)