P5: fix reference attention for MQA/GQA (kv_idx = h // q_per_kv)

This commit is contained in:
2026-05-30 08:59:50 +00:00
parent c55030a340
commit b61df2657b

View File

@@ -25,20 +25,24 @@ def cosine_sim(a, b):
def reference_attention(q_4d, k_4d, v_4d, scale):
"""PyTorch reference matching kernel tensor layout.
Q: (1, n_h, 1, hd), K: (1, n_h, N, hd), V: (1, n_h, hd, N)
Q: (1, n_h, 1, hd), K: (1, n_kv, N, hd), V: (1, n_kv, hd, N)
V is in kernel layout (hd, N) — transpose to (N, hd) for reference.
For MQA/GQA, each Q head uses its corresponding KV head.
"""
n_h = q_4d.shape[1]
N = k_4d.shape[2] # total KV length (may be > 128)
n_kv = k_4d.shape[1]
N = k_4d.shape[2]
q_per_kv = n_h // n_kv
q = q_4d[0] # (n_h, 1, hd)
k = k_4d[0] # (n_h, N, hd)
v = v_4d[0].transpose(-1, -2) # (n_h, N, hd)
k = k_4d[0] # (n_kv, N, hd)
v = v_4d[0].transpose(-1, -2) # (n_kv, N, hd)
output = torch.zeros(n_h, 1, q_4d.shape[3], dtype=torch.bfloat16, device='cuda')
for h in range(n_h):
kv_idx = h // q_per_kv
q_h = q[h] # (1, hd)
k_h = k[h] # (N, hd)
v_h = v[h] # (N, hd)
k_h = k[kv_idx] # (N, hd)
v_h = v[kv_idx] # (N, hd)
s = torch.matmul(q_h.float(), k_h.float().T) * scale
s = torch.softmax(s, dim=-1)
o = torch.matmul(s, v_h.float())