diff --git a/tests/verify_attention.py b/tests/verify_attention.py index 816360e5..3998844b 100644 --- a/tests/verify_attention.py +++ b/tests/verify_attention.py @@ -169,6 +169,40 @@ def main(): print(f"\nInverse RoPE recovery: max diff = {diff:.6f} (should be ~0)") # === Output projection === + # For GQA, the attention output is (n_h, T, hd) + # Each Q head attended to the same KV, producing its own output + # For this test with 1 KV entry, all heads produce the same V + # In practice, each head has different Q, so different attention weights + # Let's use the actual multi-head attention output + + # Proper multi-head attention with SDPA + q_input = q_roped # (1, n_h, hd) = (1, 128, 512) + k_input = kv_roped.expand(n_h, -1, -1) # (n_h, 1, hd) = (128, 1, 512) + v_input = kv_roped.expand(n_h, -1, -1) + + attn_out_full = torch.nn.functional.scaled_dot_product_attention( + q_input, k_input, v_input, + scale=1.0 / math.sqrt(hd), is_causal=False) # (1, n_h, hd) + # Wait, shapes are (1, 128, 512) for q but (128, 1, 512) for k/v + # Need to fix: q is (1, n_h, hd) → permute to (n_h, 1, hd) + q_input = q_roped.permute(1, 0, 2) # (n_h, 1, hd) = (128, 1, 512) + k_input = kv_roped.squeeze(0).expand(n_h, -1, -1) # (n_h, 1, hd) + v_input = kv_roped.squeeze(0).expand(n_h, -1, -1) # (n_h, 1, hd) + + attn_out_full = torch.nn.functional.scaled_dot_product_attention( + q_input, k_input, v_input, + scale=1.0 / math.sqrt(hd), is_causal=False) # (n_h, 1, hd) + attn_out_full = attn_out_full.permute(1, 0, 2) # (1, n_h, hd) + + # Inverse RoPE + attn_out_inv = apply_inverse_rope(attn_out_full, cos.unsqueeze(0), sin.unsqueeze(0)) + print(f"\nMulti-head attn output: shape={attn_out_inv.shape}, |attn|={attn_out_inv.abs().max():.4f}") + + # Compare: per-head output should be close to kv (since 1 KV entry) + for h in [0, 1, 63, 127]: + diff = (attn_out_inv[0, h].float() - kv_heads[0, 0].float()).abs().max() + print(f" Head {h}: max diff from kv = {diff:.6f}") + attn_flat = attn_out_inv.reshape(1, n_h * hd) # (1, 65536) print(f"\nattn_flat: shape={attn_flat.shape}, |attn_flat|={attn_flat.abs().max():.4f}")