debug: single head sanity test with known values

This commit is contained in:
2026-05-30 08:25:20 +00:00
parent 1c74d35fb4
commit cfac224b59

View File

@@ -0,0 +1,67 @@
"""
Absolute simplest test: single head, small N, verify kernel == reference.
"""
import torch
import math
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from dsv4.kernels.attention.fmha_multihead_op import fmha_multihead_decode_raw
def test_single_head_sanity():
    """Single head, N=128, hd=64. Known values, no randomness.

    Q and K are all ones, so every attention score is identical and the
    softmax is uniform (1/N). The output is therefore just the mean of V over
    the N positions, which has a closed form for the V pattern used here.

    The kernel output is checked three ways:
      * cosine similarity against the closed-form expectation,
      * cosine similarity against a float32 matmul reference,
      * element-wise closeness against the matmul reference (cosine
        similarity is scale-invariant, so on its own it cannot detect a
        wrong softmax normalisation).

    Raises:
        AssertionError: if the kernel output disagrees with either reference.
    """
    hd = 64
    N = 128
    scale = 1.0 / math.sqrt(hd)
    device = 'cuda'

    # Q: (1, 1, 1, hd) — single head, single query token
    q = torch.ones(1, 1, 1, hd, dtype=torch.bfloat16, device=device)
    # K: (1, 1, N, hd) — single KV head, N positions
    k = torch.ones(1, 1, N, hd, dtype=torch.bfloat16, device=device)

    # V: (1, 1, hd, N) — in kernel layout, with V[d, r] = d + r*0.01
    v_data = (
        torch.arange(hd, dtype=torch.float32, device=device).unsqueeze(1)
        + torch.arange(N, dtype=torch.float32, device=device).unsqueeze(0) * 0.01
    )
    v_4d = v_data.bfloat16().unsqueeze(0).unsqueeze(0)  # (1, 1, hd, N)

    sb = torch.zeros(1, 1, dtype=torch.float32, device=device)
    # NOTE(review): the meaning of the positional `0, 0, False` arguments and
    # of the zero-filled `sb` tensor is defined by fmha_multihead_decode_raw —
    # confirm against its signature.
    o_4d, _lse = fmha_multihead_decode_raw(q, k, v_4d, scale, 0, 0, False, sb)
    kernel_out = o_4d[0, 0].float().flatten()  # (hd,)

    # Closed-form reference (exact, un-rounded V):
    #   sum over r of (d + r*0.01) = N*d + 0.01*N*(N-1)/2
    #   O[d] = (1/N) * that        = d + 0.01*(N-1)/2
    o_expected = (
        torch.arange(hd, dtype=torch.float32, device=device)
        + 0.01 * (N - 1) / 2
    )
    cos = torch.nn.functional.cosine_similarity(
        kernel_out.unsqueeze(0),
        o_expected.flatten().unsqueeze(0),
    ).item()

    # Matmul reference in float32. It uses the bf16-rounded V that the kernel
    # actually receives, so input quantisation does not count as kernel error.
    q_f = q.float().squeeze()   # (hd,) all ones
    k_f = k.float().squeeze()   # (N, hd) all ones
    v_f = v_4d[0, 0].float().T  # (N, hd)
    scores = torch.matmul(q_f.unsqueeze(0), k_f.T) * scale  # (1, N)
    probs = torch.softmax(scores, dim=-1)                   # (1, N)
    o_matmul = torch.matmul(probs, v_f)                     # (1, hd)
    cos_matmul = torch.nn.functional.cosine_similarity(
        kernel_out.unsqueeze(0),
        o_matmul.flatten().unsqueeze(0),
    ).item()

    # Print before asserting so the diagnostics are visible on failure.
    print(f"Kernel vs expected: cos={cos:.6f}")
    print(f"Kernel vs matmul: cos={cos_matmul:.6f}")
    print(f"Kernel output[0:5]: {o_4d[0, 0, 0, 0:5].float()}")
    print(f"Expected[0:5]: {o_expected[0:5]}")
    print(f"Matmul[0:5]: {o_matmul[0, 0:5]}")

    assert cos > 0.999, f"kernel vs closed-form cosine too low: {cos:.6f}"
    assert cos_matmul > 0.999, (
        f"kernel vs matmul cosine too low: {cos_matmul:.6f}"
    )
    # Tolerances are sized for bf16 (8 significant bits, ~0.4% relative).
    torch.testing.assert_close(
        kernel_out, o_matmul.flatten(), rtol=2e-2, atol=2e-2
    )
if __name__ == "__main__":
    # Allow running this file directly as a script, without a test runner.
    test_single_head_sanity()