From dd3a12bbda6469c7c7bef1aa8ff2b5403dd3e743 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 19 May 2026 07:59:28 +0000 Subject: [PATCH] Fix full_attention_reference: broadcast KV to all heads+positions --- cutedsl/csa_attention.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cutedsl/csa_attention.py b/cutedsl/csa_attention.py index 87273e60..7f30eab9 100644 --- a/cutedsl/csa_attention.py +++ b/cutedsl/csa_attention.py @@ -393,10 +393,11 @@ def full_attention_reference( T, NH, HD = q.shape # K=V from kv latent (shared across heads, so expand) - k = kv.unsqueeze(1).expand(-1, NH, -1).contiguous() # (T, NH, HD) + # kv: (T, HD) → broadcast to all heads and all query positions + k = kv.unsqueeze(0).unsqueeze(2).expand(T, NH, T, -1).contiguous() # (T, NH, T, HD) v = k.clone() - # Reshape for SDPA: (T*NH, 1, HD) and (T*NH, T, HD) + # Reshape for SDPA: Q (T*NH, 1, HD), K (T*NH, T, HD), V (T*NH, T, HD) q_2d = q.reshape(T * NH, 1, HD) k_2d = k.reshape(T * NH, T, HD) v_2d = v.reshape(T * NH, T, HD)