Fix sink logits shape: build them as (n_h, T, 1) so they can be concatenated with the (n_h, T, seq_len) scores along the last dimension
This commit is contained in:
@@ -493,12 +493,13 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
|
||||
scores_raw = torch.matmul(q_input, k_expanded.transpose(-1, -2)) * scale # (n_h, T, seq_len)
|
||||
if sink_key in w and seq_len > 0:
|
||||
sinks = w[sink_key].to(device=device) # (n_h,) BF16
|
||||
sink_logits = sinks.reshape(1, -1, 1, 1).expand(q_input.shape[0], -1, q_input.shape[-2], -1)
|
||||
combined_logits = torch.cat([scores_raw.unsqueeze(0), sink_logits.to(scores_raw.dtype)], dim=-1)
|
||||
# sinks: (n_h,) → reshape to (n_h, 1, 1), then expand to (n_h, T, 1) for concatenation with scores (n_h, T, seq_len) on the last dim
|
||||
sink_logits = sinks.float().reshape(n_h, 1, 1).expand(-1, T, 1)
|
||||
combined_logits = torch.cat([scores_raw, sink_logits], dim=-1) # (n_h, T, seq_len+1)
|
||||
# Stable softmax
|
||||
combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
|
||||
probs = torch.softmax(combined_logits.float(), dim=-1).to(torch.bfloat16)
|
||||
attn_weights = probs[..., :-1] # Drop sink column
|
||||
attn_weights = probs[..., :-1] # Drop sink column (n_h, T, seq_len)
|
||||
else:
|
||||
attn_weights = torch.softmax(scores_raw.float(), dim=-1).to(torch.bfloat16)
|
||||
attn_out = torch.matmul(attn_weights, v_expanded) # (n_h, T, hd)
|
||||
|
||||
Reference in New Issue
Block a user