92 lines
3.5 KiB
Python
92 lines
3.5 KiB
Python
"""
|
|
Debug: why does the full API test give cos=0.83?

Test 1: V in kernel layout (hd, N); the reference transposes it to (N, hd).
Test 2: V in standard layout (N, hd); the reference uses it directly.

Both tests should show the same kernel-vs-reference agreement if the math is correct.
|
|
"""
|
|
import torch
|
|
import math
|
|
import sys
|
|
import os
|
|
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
from dsv4.kernels.attention.fmha_multihead_op import fmha_multihead_decode_raw
|
|
|
|
|
|
def _reference_attention(q, k, v, scale):
    """Compute softmax attention head by head in float32.

    The three tensors are iterated in lockstep along their first dimension.

    Args:
        q: Queries; each ``q[h]`` is used as a 2-D matrix (here ``(1, hd)``).
        k: Keys; each ``k[h]`` is a 2-D matrix (here ``(N, hd)``).
        v: Values; each ``v[h]`` is a 2-D matrix (here ``(N, hd)``).
        scale: Multiplier applied to the ``q @ k.T`` logits before softmax.

    Returns:
        The per-head outputs, cast to bfloat16 and stacked along a new
        leading dimension.
    """
    per_head = []
    for q_h, k_h, v_h in zip(q, k, v):
        # Upcast to float32 so the reference is not limited by bf16 matmul
        # precision; only the final result is rounded back to bf16.
        logits = torch.matmul(q_h.float(), k_h.float().T) * scale
        probs = torch.softmax(logits, dim=-1)
        per_head.append(torch.matmul(probs, v_h.float()).bfloat16())
    return torch.stack(per_head)


def _cosine(a, b):
    """Return the cosine similarity of two tensors after flattening to float32."""
    return torch.nn.functional.cosine_similarity(
        a.float().flatten().unsqueeze(0),
        b.float().flatten().unsqueeze(0),
    ).item()


def test_v_layout_comparison() -> None:
    """Compare the kernel against a PyTorch reference for two V sources.

    Q and K are shared by both runs. Two *independent* random V tensors are
    fed to the kernel, both shaped ``(1, n_h, hd, N)``:

    * ``v_native``     -- allocated directly in that shape.
    * ``v_transposed`` -- allocated as ``(n_h, N, hd)`` and then transposed
      and made contiguous.

    Each kernel output is compared with a reference built from the matching
    V data, and the per-head cosine similarities are printed. Nothing is
    asserted; this is a diagnostic script. Requires a CUDA device.
    """
    torch.manual_seed(42)

    hd = 64
    n_h = 4
    N = 128
    scale = 1.0 / math.sqrt(hd)

    # Create Q and K once. The order of the randn calls is kept fixed so the
    # seeded values stay reproducible.
    q_4d = torch.randn(1, n_h, 1, hd, dtype=torch.bfloat16, device='cuda').contiguous()
    k_4d = torch.randn(1, n_h, N, hd, dtype=torch.bfloat16, device='cuda').contiguous()

    # V created directly as (1, n_h, hd, N).
    v_native = torch.randn(1, n_h, hd, N, dtype=torch.bfloat16, device='cuda').contiguous()

    # A second, independent V created as (n_h, N, hd), then transposed.
    v_orig = torch.randn(n_h, N, hd, dtype=torch.bfloat16, device='cuda')
    v_transposed = v_orig.unsqueeze(0).transpose(-1, -2).contiguous()  # (1, n_h, hd, N)

    # NOTE(review): the same `sb` tensor is passed to both kernel calls; this
    # presumes the kernel does not modify it -- confirm against the op.
    sb = torch.zeros(1, n_h, dtype=torch.float32, device='cuda')

    # Run kernel with native V
    o_native, _ = fmha_multihead_decode_raw(q_4d, k_4d, v_native, scale, 0, 0, False, sb)

    # Run kernel with transposed-from-standard V
    o_transposed, _ = fmha_multihead_decode_raw(q_4d, k_4d, v_transposed, scale, 0, 0, False, sb)

    q_ref = q_4d[0]  # (n_h, 1, hd)
    k_ref = k_4d[0]  # (n_h, N, hd)

    # Reference 1: native V, transposed from (hd, N) back to (N, hd).
    o_ref1 = _reference_attention(q_ref, k_ref, v_native[0].transpose(-1, -2), scale)

    # Reference 2: v_orig is already (n_h, N, hd). It holds the same values
    # the kernel received through v_transposed (NOT the values in v_native).
    o_ref2 = _reference_attention(q_ref, k_ref, v_orig, scale)

    # Kernel outputs are indexed as [0, h] -- presumably batch then head.
    for h in range(n_h):
        cos1 = _cosine(o_native[0, h], o_ref1[h])
        cos2 = _cosine(o_transposed[0, h], o_ref2[h])
        # The two kernel outputs should differ, since their V data differs.
        cos_kk = _cosine(o_native[0, h], o_transposed[0, h])
        print(f" Head {h}: native_vs_ref1={cos1:.6f} transposed_vs_ref2={cos2:.6f} native_vs_transposed={cos_kk:.6f}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the layout comparison when executed directly.
    test_v_layout_comparison()
|