debug: V layout reference comparison
This commit is contained in:
91
tests/unit/test_p3_v_ref_debug.py
Normal file
91
tests/unit/test_p3_v_ref_debug.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
Debug: why does the full API test give cos=0.83?
|
||||
Test 1: V in kernel layout (hd, N), reference transposes -> (N, hd)
|
||||
Test 2: V in standard layout (N, hd), reference uses directly
|
||||
Both should give same result if math is correct.
|
||||
"""
|
||||
import torch
|
||||
import math
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from dsv4.kernels.attention.fmha_multihead_op import fmha_multihead_decode_raw
|
||||
|
||||
|
||||
def test_v_layout_comparison():
|
||||
"""Direct comparison: same Q and K, V in two different layouts."""
|
||||
torch.manual_seed(42)
|
||||
|
||||
hd = 64
|
||||
n_h = 4
|
||||
N = 128
|
||||
scale = 1.0 / math.sqrt(hd)
|
||||
|
||||
# Create Q and K once
|
||||
q_4d = torch.randn(1, n_h, 1, hd, dtype=torch.bfloat16, device='cuda').contiguous()
|
||||
k_4d = torch.randn(1, n_h, N, hd, dtype=torch.bfloat16, device='cuda').contiguous()
|
||||
|
||||
# Create V as (n_h, hd, N) natively
|
||||
v_native = torch.randn(1, n_h, hd, N, dtype=torch.bfloat16, device='cuda').contiguous()
|
||||
|
||||
# Also create V as (n_h, N, hd) then transpose
|
||||
v_orig = torch.randn(n_h, N, hd, dtype=torch.bfloat16, device='cuda')
|
||||
v_transposed = v_orig.unsqueeze(0).transpose(-1, -2).contiguous() # (1, n_h, hd, N)
|
||||
|
||||
# Run kernel with native V
|
||||
sb = torch.zeros(1, n_h, dtype=torch.float32, device='cuda')
|
||||
o_native, _ = fmha_multihead_decode_raw(q_4d, k_4d, v_native, scale, 0, 0, False, sb)
|
||||
|
||||
# Run kernel with transposed-from-standard V
|
||||
o_transposed, _ = fmha_multihead_decode_raw(q_4d, k_4d, v_transposed, scale, 0, 0, False, sb)
|
||||
|
||||
# Reference with native V (hd, N) -> transpose to (N, hd)
|
||||
q_ref = q_4d[0] # (n_h, 1, hd)
|
||||
k_ref = k_4d[0] # (n_h, N, hd)
|
||||
v_ref_native = v_native[0].transpose(-1, -2) # (n_h, N, hd) — transposed from (hd, N)
|
||||
v_ref_orig = v_orig # (n_h, N, hd) — already in (N, hd) layout
|
||||
|
||||
# Reference 1: using native V data
|
||||
o_ref1 = torch.zeros(n_h, 1, hd, dtype=torch.bfloat16, device='cuda')
|
||||
for h in range(n_h):
|
||||
q_h = q_ref[h] # (1, hd)
|
||||
k_h = k_ref[h] # (N, hd)
|
||||
v_h = v_ref_native[h] # (N, hd)
|
||||
s = torch.matmul(q_h.float(), k_h.float().T) * scale
|
||||
s = torch.softmax(s, dim=-1)
|
||||
o = torch.matmul(s, v_h.float())
|
||||
o_ref1[h] = o.bfloat16()
|
||||
|
||||
# Reference 2: using original V data
|
||||
o_ref2 = torch.zeros(n_h, 1, hd, dtype=torch.bfloat16, device='cuda')
|
||||
for h in range(n_h):
|
||||
q_h = q_ref[h]
|
||||
k_h = k_ref[h]
|
||||
v_h = v_ref_orig[h] # (N, hd) — same data, different source
|
||||
s = torch.matmul(q_h.float(), k_h.float().T) * scale
|
||||
s = torch.softmax(s, dim=-1)
|
||||
o = torch.matmul(s, v_h.float())
|
||||
o_ref2[h] = o.bfloat16()
|
||||
|
||||
# Compare kernel vs ref1 (native V)
|
||||
for h in range(n_h):
|
||||
cos1 = torch.nn.functional.cosine_similarity(
|
||||
o_native[0, h].float().flatten().unsqueeze(0),
|
||||
o_ref1[h].float().flatten().unsqueeze(0),
|
||||
).item()
|
||||
cos2 = torch.nn.functional.cosine_similarity(
|
||||
o_transposed[0, h].float().flatten().unsqueeze(0),
|
||||
o_ref2[h].float().flatten().unsqueeze(0),
|
||||
).item()
|
||||
# Also compare the two kernel outputs (should differ since different V data)
|
||||
cos_kk = torch.nn.functional.cosine_similarity(
|
||||
o_native[0, h].float().flatten().unsqueeze(0),
|
||||
o_transposed[0, h].float().flatten().unsqueeze(0),
|
||||
).item()
|
||||
print(f" Head {h}: native_vs_ref1={cos1:.6f} transposed_vs_ref2={cos2:.6f} native_vs_transposed={cos_kk:.6f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_v_layout_comparison()
|
||||
Reference in New Issue
Block a user