diff --git a/single_shot_inference.py b/single_shot_inference.py index 552357bd..e7c3e64d 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -376,9 +376,10 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, if seq_len < 120: k_expanded = k_full.expand(n_h, -1, -1).contiguous() v_expanded = v_full.expand(n_h, -1, -1).contiguous() - # Add attention sink (paper D5c) + # Attention sink (paper D5c) + # DISABLED for now to check impact sink_key = f"{pre}.sinks" - if sink_key in w and seq_len > 0: + if False and sink_key in w and seq_len > 0: sinks = w[sink_key].to(device=device) # (n_h,) BF16 sink_k = torch.zeros(n_h, 1, hd, dtype=torch.bfloat16, device=device) sink_v = torch.zeros(n_h, 1, hd, dtype=torch.bfloat16, device=device)