diff --git a/cutedsl/csa_attention.py b/cutedsl/csa_attention.py index 99d73f89..87273e60 100644 --- a/cutedsl/csa_attention.py +++ b/cutedsl/csa_attention.py @@ -392,9 +392,9 @@ def full_attention_reference( """ T, NH, HD = q.shape - # K=V from kv latent (MLA-style: single KV, shared across heads) - k = kv.unsqueeze(1).expand(-1, NH, -1) # (T, NH, HD) - v = kv.unsqueeze(1).expand(-1, NH, -1) # (T, NH, HD) + # K=V from kv latent (shared across heads, so expand) + k = kv.unsqueeze(1).expand(-1, NH, -1).contiguous() # (T, NH, HD) + v = k.clone() # Reshape for SDPA: (T*NH, 1, HD) and (T*NH, T, HD) q_2d = q.reshape(T * NH, 1, HD)