Fix attention sinks: logit bias (HuggingFace reference), not dummy KV
The HuggingFace reference treats attention sinks as a logit bias: 1. Compute raw Q*K scores 2. Concatenate sinks as a logit column 3. Softmax the combined logits 4. DROP the sink column (don't multiply by V) 5. Multiply by V Our old code added sinks as a dummy zero-KV entry, which diluted attention weights by adding an extra V=0 position to the softmax.
This commit is contained in:
@@ -485,23 +485,23 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
|
||||
if seq_len < 120:
|
||||
k_expanded = k_full.expand(n_h, -1, -1).contiguous()
|
||||
v_expanded = v_full.expand(n_h, -1, -1).contiguous()
|
||||
# Attention sink (paper D5c)
|
||||
# Attention: compute raw scores, add sinks as logit bias, softmax, multiply by V
|
||||
# (paper D5c, matching HuggingFace reference implementation)
|
||||
# Sinks are added as a logit column, softmaxed together, then DROPPED
|
||||
# before V multiplication — NOT as a dummy KV entry.
|
||||
sink_key = f"{pre}.sinks"
|
||||
scores_raw = torch.matmul(q_input, k_expanded.transpose(-1, -2)) * scale # (n_h, T, seq_len)
|
||||
if sink_key in w and seq_len > 0:
|
||||
sinks = w[sink_key].to(device=device) # (n_h,) BF16
|
||||
sink_k = torch.zeros(n_h, 1, hd, dtype=torch.bfloat16, device=device)
|
||||
sink_v = torch.zeros(n_h, 1, hd, dtype=torch.bfloat16, device=device)
|
||||
k_with_sink = torch.cat([k_expanded, sink_k], dim=1)
|
||||
v_with_sink = torch.cat([v_expanded, sink_v], dim=1)
|
||||
sink_bias_mask = torch.zeros(n_h, T, seq_len + 1, dtype=torch.bfloat16, device=device)
|
||||
for h in range(n_h):
|
||||
sink_bias_mask[h, :, -1] = sinks[h]
|
||||
attn_out = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_input, k_with_sink, v_with_sink,
|
||||
attn_mask=sink_bias_mask, scale=scale)
|
||||
sink_logits = sinks.reshape(1, -1, 1, 1).expand(q_input.shape[0], -1, q_input.shape[-2], -1)
|
||||
combined_logits = torch.cat([scores_raw.unsqueeze(0), sink_logits.to(scores_raw.dtype)], dim=-1)
|
||||
# Stable softmax
|
||||
combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
|
||||
probs = torch.softmax(combined_logits.float(), dim=-1).to(torch.bfloat16)
|
||||
attn_weights = probs[..., :-1] # Drop sink column
|
||||
else:
|
||||
attn_out = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_input, k_expanded, v_expanded, scale=scale, is_causal=False)
|
||||
attn_weights = torch.softmax(scores_raw.float(), dim=-1).to(torch.bfloat16)
|
||||
attn_out = torch.matmul(attn_weights, v_expanded) # (n_h, T, hd)
|
||||
attn_out = attn_out.permute(1, 0, 2) # (T, n_h, hd)
|
||||
# Diagnostic: check attention entropy (how spread out the attention is)
|
||||
if MHC_DIAG and li < 3:
|
||||
|
||||
Reference in New Issue
Block a user