Fix full_attention_reference: broadcast KV across all heads and query positions
This commit is contained in:
@@ -393,10 +393,11 @@ def full_attention_reference(
|
||||
T, NH, HD = q.shape
|
||||
|
||||
# K=V from kv latent (shared across heads, so expand)
|
||||
k = kv.unsqueeze(1).expand(-1, NH, -1).contiguous() # (T, NH, HD)
|
||||
# kv: (T, HD) → broadcast to all heads and all query positions
|
||||
k = kv.unsqueeze(0).unsqueeze(2).expand(T, NH, T, -1).contiguous() # (T, NH, T, HD)
|
||||
v = k.clone()
|
||||
|
||||
# Reshape for SDPA: (T*NH, 1, HD) and (T*NH, T, HD)
|
||||
# Reshape for SDPA: Q (T*NH, 1, HD), K (T*NH, T, HD), V (T*NH, T, HD)
|
||||
q_2d = q.reshape(T * NH, 1, HD)
|
||||
k_2d = k.reshape(T * NH, T, HD)
|
||||
v_2d = v.reshape(T * NH, T, HD)
|
||||
|
||||
Reference in New Issue
Block a user