fix verify_attention: proper multi-head SDPA + GQA
This commit is contained in:
@@ -169,6 +169,40 @@ def main():
|
||||
print(f"\nInverse RoPE recovery: max diff = {diff:.6f} (should be ~0)")
|
||||
|
||||
# === Output projection ===
|
||||
# For GQA, the attention output is (n_h, T, hd)
|
||||
# Each Q head attended to the same KV, producing its own output
|
||||
# For this test with 1 KV entry, all heads produce the same V
|
||||
# In practice, each head has different Q, so different attention weights
|
||||
# Let's use the actual multi-head attention output
|
||||
|
||||
# Proper multi-head attention with SDPA
|
||||
q_input = q_roped # (1, n_h, hd) = (1, 128, 512)
|
||||
k_input = kv_roped.expand(n_h, -1, -1) # (n_h, 1, hd) = (128, 1, 512)
|
||||
v_input = kv_roped.expand(n_h, -1, -1)
|
||||
|
||||
attn_out_full = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_input, k_input, v_input,
|
||||
scale=1.0 / math.sqrt(hd), is_causal=False) # (1, n_h, hd)
|
||||
# NOTE: the SDPA call above has mismatched layouts: q is (1, n_h, hd) = (1, 128, 512) but k/v are (n_h, 1, hd) = (128, 1, 512)
|
||||
# Fix: permute q from (1, n_h, hd) to (n_h, 1, hd) so its layout matches k/v
|
||||
q_input = q_roped.permute(1, 0, 2) # (n_h, 1, hd) = (128, 1, 512)
|
||||
k_input = kv_roped.squeeze(0).expand(n_h, -1, -1) # (n_h, 1, hd)
|
||||
v_input = kv_roped.squeeze(0).expand(n_h, -1, -1) # (n_h, 1, hd)
|
||||
|
||||
attn_out_full = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_input, k_input, v_input,
|
||||
scale=1.0 / math.sqrt(hd), is_causal=False) # (n_h, 1, hd)
|
||||
attn_out_full = attn_out_full.permute(1, 0, 2) # (1, n_h, hd)
|
||||
|
||||
# Inverse RoPE
|
||||
attn_out_inv = apply_inverse_rope(attn_out_full, cos.unsqueeze(0), sin.unsqueeze(0))
|
||||
print(f"\nMulti-head attn output: shape={attn_out_inv.shape}, |attn|={attn_out_inv.abs().max():.4f}")
|
||||
|
||||
# Compare: per-head output should be close to kv (since 1 KV entry)
|
||||
for h in [0, 1, 63, 127]:
|
||||
diff = (attn_out_inv[0, h].float() - kv_heads[0, 0].float()).abs().max()
|
||||
print(f" Head {h}: max diff from kv = {diff:.6f}")
|
||||
|
||||
attn_flat = attn_out_inv.reshape(1, n_h * hd) # (1, 65536)
|
||||
print(f"\nattn_flat: shape={attn_flat.shape}, |attn_flat|={attn_flat.abs().max():.4f}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user