debug: disable sinks in SDPA to check |X| impact
This commit is contained in:
@@ -376,9 +376,10 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
|
||||
if seq_len < 120:
|
||||
k_expanded = k_full.expand(n_h, -1, -1).contiguous()
|
||||
v_expanded = v_full.expand(n_h, -1, -1).contiguous()
|
||||
# Add attention sink (paper D5c)
|
||||
# Attention sink (paper D5c)
|
||||
# DISABLED for now to check impact
|
||||
sink_key = f"{pre}.sinks"
|
||||
if sink_key in w and seq_len > 0:
|
||||
if False and sink_key in w and seq_len > 0:
|
||||
sinks = w[sink_key].to(device=device) # (n_h,) BF16
|
||||
sink_k = torch.zeros(n_h, 1, hd, dtype=torch.bfloat16, device=device)
|
||||
sink_v = torch.zeros(n_h, 1, hd, dtype=torch.bfloat16, device=device)
|
||||
|
||||
Reference in New Issue
Block a user