P5: fix reference attention for MQA/GQA (kv_idx = h // q_per_kv)
This commit is contained in:
@@ -25,20 +25,24 @@ def cosine_sim(a, b):
|
||||
def reference_attention(q_4d, k_4d, v_4d, scale):
|
||||
"""PyTorch reference matching kernel tensor layout.
|
||||
|
||||
Q: (1, n_h, 1, hd), K: (1, n_h, N, hd), V: (1, n_h, hd, N)
|
||||
Q: (1, n_h, 1, hd), K: (1, n_kv, N, hd), V: (1, n_kv, hd, N)
|
||||
V is in kernel layout (hd, N) — transpose to (N, hd) for reference.
|
||||
For MQA/GQA, each Q head uses its corresponding KV head.
|
||||
"""
|
||||
n_h = q_4d.shape[1]
|
||||
N = k_4d.shape[2] # total KV length (may be > 128)
|
||||
n_kv = k_4d.shape[1]
|
||||
N = k_4d.shape[2]
|
||||
q_per_kv = n_h // n_kv
|
||||
q = q_4d[0] # (n_h, 1, hd)
|
||||
k = k_4d[0] # (n_h, N, hd)
|
||||
v = v_4d[0].transpose(-1, -2) # (n_h, N, hd)
|
||||
k = k_4d[0] # (n_kv, N, hd)
|
||||
v = v_4d[0].transpose(-1, -2) # (n_kv, N, hd)
|
||||
|
||||
output = torch.zeros(n_h, 1, q_4d.shape[3], dtype=torch.bfloat16, device='cuda')
|
||||
for h in range(n_h):
|
||||
kv_idx = h // q_per_kv
|
||||
q_h = q[h] # (1, hd)
|
||||
k_h = k[h] # (N, hd)
|
||||
v_h = v[h] # (N, hd)
|
||||
k_h = k[kv_idx] # (N, hd)
|
||||
v_h = v[kv_idx] # (N, hd)
|
||||
s = torch.matmul(q_h.float(), k_h.float().T) * scale
|
||||
s = torch.softmax(s, dim=-1)
|
||||
o = torch.matmul(s, v_h.float())
|
||||
|
||||
Reference in New Issue
Block a user