diff --git a/tests/unit/test_p3_v_ref_debug.py b/tests/unit/test_p3_v_ref_debug.py new file mode 100644 index 00000000..072701f3 --- /dev/null +++ b/tests/unit/test_p3_v_ref_debug.py @@ -0,0 +1,91 @@ +""" +Debug: why does the full API test give cos=0.83? +Test 1: V in kernel layout (hd, N), reference transposes -> (N, hd) +Test 2: V in standard layout (N, hd), reference uses directly +Both should give same result if math is correct. +""" +import torch +import math +import sys +import os + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from dsv4.kernels.attention.fmha_multihead_op import fmha_multihead_decode_raw + + +def test_v_layout_comparison(): + """Direct comparison: same Q and K, V in two different layouts.""" + torch.manual_seed(42) + + hd = 64 + n_h = 4 + N = 128 + scale = 1.0 / math.sqrt(hd) + + # Create Q and K once + q_4d = torch.randn(1, n_h, 1, hd, dtype=torch.bfloat16, device='cuda').contiguous() + k_4d = torch.randn(1, n_h, N, hd, dtype=torch.bfloat16, device='cuda').contiguous() + + # Create V as (n_h, hd, N) natively + v_native = torch.randn(1, n_h, hd, N, dtype=torch.bfloat16, device='cuda').contiguous() + + # Also create V as (n_h, N, hd) then transpose + v_orig = torch.randn(n_h, N, hd, dtype=torch.bfloat16, device='cuda') + v_transposed = v_orig.unsqueeze(0).transpose(-1, -2).contiguous() # (1, n_h, hd, N) + + # Run kernel with native V + sb = torch.zeros(1, n_h, dtype=torch.float32, device='cuda') + o_native, _ = fmha_multihead_decode_raw(q_4d, k_4d, v_native, scale, 0, 0, False, sb) + + # Run kernel with transposed-from-standard V + o_transposed, _ = fmha_multihead_decode_raw(q_4d, k_4d, v_transposed, scale, 0, 0, False, sb) + + # Reference with native V (hd, N) -> transpose to (N, hd) + q_ref = q_4d[0] # (n_h, 1, hd) + k_ref = k_4d[0] # (n_h, N, hd) + v_ref_native = v_native[0].transpose(-1, -2) # (n_h, N, hd) — transposed from (hd, N) + v_ref_orig = v_orig # (n_h, N, hd) — already in (N, hd) layout + + # Reference 1: using native V data + o_ref1 = torch.zeros(n_h, 1, hd, dtype=torch.bfloat16, device='cuda') + for h in range(n_h): + q_h = q_ref[h] # (1, hd) + k_h = k_ref[h] # (N, hd) + v_h = v_ref_native[h] # (N, hd) + s = torch.matmul(q_h.float(), k_h.float().T) * scale + s = torch.softmax(s, dim=-1) + o = torch.matmul(s, v_h.float()) + o_ref1[h] = o.bfloat16() + + # Reference 2: using original V data + o_ref2 = torch.zeros(n_h, 1, hd, dtype=torch.bfloat16, device='cuda') + for h in range(n_h): + q_h = q_ref[h] + k_h = k_ref[h] + v_h = v_ref_orig[h] # (N, hd) — same data, different source + s = torch.matmul(q_h.float(), k_h.float().T) * scale + s = torch.softmax(s, dim=-1) + o = torch.matmul(s, v_h.float()) + o_ref2[h] = o.bfloat16() + + # Compare kernel vs ref1 (native V) + for h in range(n_h): + cos1 = torch.nn.functional.cosine_similarity( + o_native[0, h].float().flatten().unsqueeze(0), + o_ref1[h].float().flatten().unsqueeze(0), + ).item() + cos2 = torch.nn.functional.cosine_similarity( + o_transposed[0, h].float().flatten().unsqueeze(0), + o_ref2[h].float().flatten().unsqueeze(0), + ).item() + # Also compare the two kernel outputs (should differ since different V data) + cos_kk = torch.nn.functional.cosine_similarity( + o_native[0, h].float().flatten().unsqueeze(0), + o_transposed[0, h].float().flatten().unsqueeze(0), + ).item() + print(f" Head {h}: native_vs_ref1={cos1:.6f} transposed_vs_ref2={cos2:.6f} native_vs_transposed={cos_kk:.6f}") + + +if __name__ == "__main__": + test_v_layout_comparison()