From 910015c47e532ea9029e086f7141cdbaddbffe22 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 19 May 2026 07:58:42 +0000 Subject: [PATCH] Fix kv shape: expand to (T, NH, HD) before reshape --- cutedsl/csa_attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cutedsl/csa_attention.py b/cutedsl/csa_attention.py index 99d73f89..87273e60 100644 --- a/cutedsl/csa_attention.py +++ b/cutedsl/csa_attention.py @@ -392,9 +392,9 @@ def full_attention_reference( """ T, NH, HD = q.shape - # K=V from kv latent (MLA-style: single KV, shared across heads) - k = kv.unsqueeze(1).expand(-1, NH, -1) # (T, NH, HD) - v = kv.unsqueeze(1).expand(-1, NH, -1) # (T, NH, HD) + # K=V from kv latent (shared across heads, so expand) + k = kv.unsqueeze(1).expand(-1, NH, -1).contiguous() # (T, NH, HD) + v = k.clone() # Reshape for SDPA: (T*NH, 1, HD) and (T*NH, T, HD) q_2d = q.reshape(T * NH, 1, HD)