Fix kv shape: expand to (T, NH, HD) before reshape
This commit is contained in:
@@ -392,9 +392,9 @@ def full_attention_reference(
|
||||
"""
|
||||
T, NH, HD = q.shape
|
||||
|
||||
# K=V from kv latent (MLA-style: single KV, shared across heads)
|
||||
k = kv.unsqueeze(1).expand(-1, NH, -1) # (T, NH, HD)
|
||||
v = kv.unsqueeze(1).expand(-1, NH, -1) # (T, NH, HD)
|
||||
# K=V from kv latent (shared across heads, so expand)
|
||||
k = kv.unsqueeze(1).expand(-1, NH, -1).contiguous() # (T, NH, HD)
|
||||
v = k.clone()
|
||||
|
||||
# Reshape for SDPA: (T*NH, 1, HD) and (T*NH, T, HD)
|
||||
q_2d = q.reshape(T * NH, 1, HD)
|
||||
|
||||
Reference in New Issue
Block a user