Fix full_attention_reference: broadcast KV to all heads+positions

This commit is contained in:
2026-05-19 07:59:28 +00:00
parent 910015c47e
commit dd3a12bbda

View File

@@ -393,10 +393,11 @@ def full_attention_reference(
T, NH, HD = q.shape
# K=V from kv latent (shared across heads, so expand)
k = kv.unsqueeze(1).expand(-1, NH, -1).contiguous() # (T, NH, HD)
# kv: (T, HD) → broadcast to all heads and all query positions
k = kv.unsqueeze(0).unsqueeze(2).expand(T, NH, T, -1).contiguous() # (T, NH, T, HD)
v = k.clone()
# Reshape for SDPA: (T*NH, 1, HD) and (T*NH, T, HD)
# Reshape for SDPA: Q (T*NH, 1, HD), K (T*NH, T, HD), V (T*NH, T, HD)
q_2d = q.reshape(T * NH, 1, HD)
k_2d = k.reshape(T * NH, T, HD)
v_2d = v.reshape(T * NH, T, HD)