From cfac224b59b06628879438e2e0e46f48a264eda2 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 30 May 2026 08:25:20 +0000 Subject: [PATCH] debug: single head sanity test with known values --- tests/unit/test_p3_sanity.py | 67 ++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 tests/unit/test_p3_sanity.py diff --git a/tests/unit/test_p3_sanity.py b/tests/unit/test_p3_sanity.py new file mode 100644 index 00000000..46b2843e --- /dev/null +++ b/tests/unit/test_p3_sanity.py @@ -0,0 +1,67 @@ +""" +Absolute simplest test: single head, small N, verify kernel == reference. +""" +import torch +import math +import sys +import os + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from dsv4.kernels.attention.fmha_multihead_op import fmha_multihead_decode_raw + + +def test_single_head_sanity(): + """Single head, N=128, hd=64. Known values, no randomness.""" + hd = 64 + N = 128 + scale = 1.0 / math.sqrt(hd) + + # Q: (1, 1, 1, hd) — single head, single query token + q = torch.ones(1, 1, 1, hd, dtype=torch.bfloat16, device='cuda') + # K: (1, 1, N, hd) — single KV head, N positions + k = torch.ones(1, 1, N, hd, dtype=torch.bfloat16, device='cuda') + # V: (1, 1, hd, N) — in kernel layout + # Let's make V[d, r] = d + r*0.01 (simple pattern) + v_data = torch.arange(hd, dtype=torch.float32, device='cuda').unsqueeze(1) + \ + torch.arange(N, dtype=torch.float32, device='cuda').unsqueeze(0) * 0.01 + v_4d = v_data.bfloat16().unsqueeze(0).unsqueeze(0) # (1, 1, hd, N) + + sb = torch.zeros(1, 1, dtype=torch.float32, device='cuda') + o_4d, lse = fmha_multihead_decode_raw(q, k, v_4d, scale, 0, 0, False, sb) + + # Reference: Q is all-ones, K is all-ones, so QK^T gives all-equal scores + # softmax of uniform = 1/N. So O = (1/N) * sum(V[r, d] for r in 0..N-1) + v_ref = v_data.T # (N, hd) — reference uses (N, hd) layout + # Each V[r, d] = d + r*0.01 + # sum over r: sum(d + r*0.01) = N*d + 0.01*sum(r) = N*d + 0.01*N*(N-1)/2 + # O[d] = (1/N) * (N*d + 0.01*N*(N-1)/2) = d + 0.01*(N-1)/2 + o_expected = torch.arange(hd, dtype=torch.float32, device='cuda') + 0.01 * (N - 1) / 2 + + cos = torch.nn.functional.cosine_similarity( + o_4d[0, 0].float().flatten().unsqueeze(0), + o_expected.flatten().unsqueeze(0), + ).item() + + # Also compute via direct matmul for sanity + q_f = q.float().squeeze() # (hd,) all ones + k_f = k.float().squeeze() # (N, hd) all ones + v_f = v_ref # (N, hd) + scores = torch.matmul(q_f.unsqueeze(0), k_f.T) * scale # (1, N) + probs = torch.softmax(scores, dim=-1) # (1, N) + o_matmul = torch.matmul(probs, v_f) # (1, hd) + + cos_matmul = torch.nn.functional.cosine_similarity( + o_4d[0, 0].float().flatten().unsqueeze(0), + o_matmul.flatten().unsqueeze(0), + ).item() + + print(f"Kernel vs expected: cos={cos:.6f}") + print(f"Kernel vs matmul: cos={cos_matmul:.6f}") + print(f"Kernel output[0:5]: {o_4d[0, 0, 0, 0:5].float()}") + print(f"Expected[0:5]: {o_expected[0:5]}") + print(f"Matmul[0:5]: {o_matmul[0, 0:5]}") + + +if __name__ == "__main__": + test_single_head_sanity()