From b61df2657be45462b6bf6f097ecc28810e14f798 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 30 May 2026 08:59:50 +0000 Subject: [PATCH] P5: fix reference attention for MQA/GQA (kv_idx = h // q_per_kv) --- tests/unit/test_p3_fast_decode.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/unit/test_p3_fast_decode.py b/tests/unit/test_p3_fast_decode.py index 332829bf..c9cb64d2 100644 --- a/tests/unit/test_p3_fast_decode.py +++ b/tests/unit/test_p3_fast_decode.py @@ -25,20 +25,24 @@ def cosine_sim(a, b): def reference_attention(q_4d, k_4d, v_4d, scale): """PyTorch reference matching kernel tensor layout. - Q: (1, n_h, 1, hd), K: (1, n_h, N, hd), V: (1, n_h, hd, N) + Q: (1, n_h, 1, hd), K: (1, n_kv, N, hd), V: (1, n_kv, hd, N) V is in kernel layout (hd, N) — transpose to (N, hd) for reference. + For MQA/GQA, each Q head uses its corresponding KV head. """ n_h = q_4d.shape[1] - N = k_4d.shape[2] # total KV length (may be > 128) + n_kv = k_4d.shape[1] + N = k_4d.shape[2] + q_per_kv = n_h // n_kv q = q_4d[0] # (n_h, 1, hd) - k = k_4d[0] # (n_h, N, hd) - v = v_4d[0].transpose(-1, -2) # (n_h, N, hd) + k = k_4d[0] # (n_kv, N, hd) + v = v_4d[0].transpose(-1, -2) # (n_kv, N, hd) output = torch.zeros(n_h, 1, q_4d.shape[3], dtype=torch.bfloat16, device='cuda') for h in range(n_h): + kv_idx = h // q_per_kv q_h = q[h] # (1, hd) - k_h = k[h] # (N, hd) - v_h = v[h] # (N, hd) + k_h = k[kv_idx] # (N, hd) + v_h = v[kv_idx] # (N, hd) s = torch.matmul(q_h.float(), k_h.float().T) * scale s = torch.softmax(s, dim=-1) o = torch.matmul(s, v_h.float())