diff --git a/single_shot_inference.py b/single_shot_inference.py index 29f0dfe0..81626e45 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -493,12 +493,13 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, scores_raw = torch.matmul(q_input, k_expanded.transpose(-1, -2)) * scale # (n_h, T, seq_len) if sink_key in w and seq_len > 0: sinks = w[sink_key].to(device=device) # (n_h,) BF16 - sink_logits = sinks.reshape(1, -1, 1, 1).expand(q_input.shape[0], -1, q_input.shape[-2], -1) - combined_logits = torch.cat([scores_raw.unsqueeze(0), sink_logits.to(scores_raw.dtype)], dim=-1) + # sinks: (n_h,) → reshape to (n_h, 1, 1) for broadcasting with (n_h, T, seq_len) + sink_logits = sinks.float().reshape(n_h, 1, 1).expand(-1, T, 1) + combined_logits = torch.cat([scores_raw, sink_logits], dim=-1) # (n_h, T, seq_len+1) # Stable softmax combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values probs = torch.softmax(combined_logits.float(), dim=-1).to(torch.bfloat16) - attn_weights = probs[..., :-1] # Drop sink column + attn_weights = probs[..., :-1] # Drop sink column (n_h, T, seq_len) else: attn_weights = torch.softmax(scores_raw.float(), dim=-1).to(torch.bfloat16) attn_out = torch.matmul(attn_weights, v_expanded) # (n_h, T, hd)