Fix sink logits shape: (n_h, T, 1) for concatenation with (n_h, T, seq_len)

This commit is contained in:
2026-05-31 11:57:23 +00:00
parent 0f951a0b1a
commit 581c4170f9

View File

@@ -493,12 +493,13 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
scores_raw = torch.matmul(q_input, k_expanded.transpose(-1, -2)) * scale # (n_h, T, seq_len)
if sink_key in w and seq_len > 0:
sinks = w[sink_key].to(device=device) # (n_h,) BF16
sink_logits = sinks.reshape(1, -1, 1, 1).expand(q_input.shape[0], -1, q_input.shape[-2], -1)
combined_logits = torch.cat([scores_raw.unsqueeze(0), sink_logits.to(scores_raw.dtype)], dim=-1)
# sinks: (n_h,) → reshape to (n_h, 1, 1) for broadcasting with (n_h, T, seq_len)
sink_logits = sinks.float().reshape(n_h, 1, 1).expand(-1, T, 1)
combined_logits = torch.cat([scores_raw, sink_logits], dim=-1) # (n_h, T, seq_len+1)
# Stable softmax
combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
probs = torch.softmax(combined_logits.float(), dim=-1).to(torch.bfloat16)
attn_weights = probs[..., :-1] # Drop sink column
attn_weights = probs[..., :-1] # Drop sink column (n_h, T, seq_len)
else:
attn_weights = torch.softmax(scores_raw.float(), dim=-1).to(torch.bfloat16)
attn_out = torch.matmul(attn_weights, v_expanded) # (n_h, T, hd)