Fix kv shape: expand to (T, NH, HD) before reshape

This commit is contained in:
2026-05-19 07:58:42 +00:00
parent 3de75c4e37
commit 910015c47e

View File

@@ -392,9 +392,9 @@ def full_attention_reference(
"""
T, NH, HD = q.shape
# K=V from kv latent (MLA-style: single KV, shared across heads)
k = kv.unsqueeze(1).expand(-1, NH, -1) # (T, NH, HD)
v = kv.unsqueeze(1).expand(-1, NH, -1) # (T, NH, HD)
# K=V from kv latent (shared across heads, so expand)
k = kv.unsqueeze(1).expand(-1, NH, -1).contiguous() # (T, NH, HD)
v = k.clone()
# Reshape for SDPA: (T*NH, 1, HD) and (T*NH, T, HD)
q_2d = q.reshape(T * NH, 1, HD)