debug: single head sanity test with known values
This commit is contained in:
67
tests/unit/test_p3_sanity.py
Normal file
67
tests/unit/test_p3_sanity.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""
|
||||
Absolute simplest test: single head, small N, verify kernel == reference.
|
||||
"""
|
||||
import torch
|
||||
import math
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from dsv4.kernels.attention.fmha_multihead_op import fmha_multihead_decode_raw
|
||||
|
||||
|
||||
def test_single_head_sanity():
    """Check the decode kernel on a single head with deterministic inputs.

    Q and K are all ones, so all N attention scores are equal and the softmax
    is uniform (1/N per position).  V follows ``V[d, r] = d + 0.01 * r``, so
    the attention output has the closed form::

        O[d] = (1/N) * sum_r (d + 0.01 * r) = d + 0.01 * (N - 1) / 2

    The kernel output is compared against that closed form and against a
    float32 matmul/softmax reference.  Diagnostics are printed first (so they
    remain visible when a check fails), then the comparisons are asserted so
    that a mismatch actually fails the test.

    Requires a CUDA device: every tensor is allocated with ``device='cuda'``.
    """
    hd = 64
    N = 128
    scale = 1.0 / math.sqrt(hd)

    # Q: (1, 1, 1, hd) — single head, single query token
    q = torch.ones(1, 1, 1, hd, dtype=torch.bfloat16, device='cuda')
    # K: (1, 1, N, hd) — single KV head, N positions
    k = torch.ones(1, 1, N, hd, dtype=torch.bfloat16, device='cuda')

    # V: (1, 1, hd, N) — in kernel layout, with V[d, r] = d + r*0.01
    d_idx = torch.arange(hd, dtype=torch.float32, device='cuda').unsqueeze(1)
    r_idx = torch.arange(N, dtype=torch.float32, device='cuda').unsqueeze(0)
    v_data = d_idx + r_idx * 0.01  # (hd, N), float32
    v_4d = v_data.bfloat16().unsqueeze(0).unsqueeze(0)  # (1, 1, hd, N)

    sb = torch.zeros(1, 1, dtype=torch.float32, device='cuda')
    # NOTE(review): the arguments after `scale` (0, 0, False, sb) are passed
    # exactly as before; their meaning is defined by fmha_multihead_decode_raw.
    # The second return value was bound to `lse` originally and is not checked.
    o_4d, _lse = fmha_multihead_decode_raw(q, k, v_4d, scale, 0, 0, False, sb)
    o_kernel = o_4d[0, 0].float().flatten()

    # Closed-form expectation: O[d] = d + 0.01 * (N - 1) / 2
    o_expected = torch.arange(hd, dtype=torch.float32, device='cuda') + 0.01 * (N - 1) / 2

    cos = torch.nn.functional.cosine_similarity(
        o_kernel.unsqueeze(0),
        o_expected.flatten().unsqueeze(0),
    ).item()

    # Also compute via direct matmul for sanity
    v_ref = v_data.T  # (N, hd) — reference uses (N, hd) layout
    q_f = q.float().squeeze()  # (hd,) all ones
    k_f = k.float().squeeze()  # (N, hd) all ones
    scores = torch.matmul(q_f.unsqueeze(0), k_f.T) * scale  # (1, N)
    probs = torch.softmax(scores, dim=-1)  # (1, N)
    o_matmul = torch.matmul(probs, v_ref)  # (1, hd)

    cos_matmul = torch.nn.functional.cosine_similarity(
        o_kernel.unsqueeze(0),
        o_matmul.flatten().unsqueeze(0),
    ).item()

    print(f"Kernel vs expected: cos={cos:.6f}")
    print(f"Kernel vs matmul: cos={cos_matmul:.6f}")
    print(f"Kernel output[0:5]: {o_4d[0, 0, 0, 0:5].float()}")
    print(f"Expected[0:5]: {o_expected[0:5]}")
    print(f"Matmul[0:5]: {o_matmul[0, 0:5]}")

    # Direction check against both references.
    assert cos > 0.999, f"kernel vs closed-form expectation: cos={cos:.6f}"
    assert cos_matmul > 0.999, f"kernel vs matmul reference: cos={cos_matmul:.6f}"
    # Cosine similarity ignores magnitude, so also compare element-wise.
    # The tolerances leave room for bfloat16 rounding of V and of the output.
    torch.testing.assert_close(o_kernel, o_matmul.flatten(), rtol=2e-2, atol=2e-2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Running this file directly executes the sanity check without pytest.
    test_single_head_sanity()
|
||||
Reference in New Issue
Block a user