Files
nvfp4-megamoe-kernel/tests/unit/test_p3_v_ref_debug.py

92 lines
3.5 KiB
Python

"""
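Debug: why does the full-API test report cos=0.83?
Case 1: V is created natively in the kernel layout (hd, N); the reference transposes it to (N, hd).
Case 2: V is created in the standard layout (N, hd) and transposed to (hd, N) for the kernel;
the reference uses the (N, hd) tensor directly.
The two cases use independent random V, so their kernel outputs differ from each other.
If the math is correct, each kernel output matches its own reference (cosine similarity ~1).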
"""
import torch
import math
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from dsv4.kernels.attention.fmha_multihead_op import fmha_multihead_decode_raw
def _reference_decode(q, k, v, scale):
    """Float32 reference for single-query attention, rounded to bf16.

    Computes ``softmax(q @ k^T * scale) @ v`` for every head at once
    (batched over the leading dimension).

    Args:
        q: (n_h, 1, hd) query.
        k: (n_h, N, hd) keys.
        v: (n_h, N, hd) values in the standard (N, hd) layout.
        scale: multiplier applied to the logits before softmax.

    Returns:
        (n_h, 1, hd) bf16 tensor.
    """
    scores = torch.matmul(q.float(), k.float().transpose(-1, -2)) * scale
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v.float()).bfloat16()


def _cosine(a, b):
    """Return the cosine similarity of two tensors, flattened and compared in float32."""
    return torch.nn.functional.cosine_similarity(
        a.float().flatten(), b.float().flatten(), dim=0
    ).item()


def test_v_layout_comparison():
    """Compare the kernel against a float32 reference for two ways of building V.

    Both kernel calls receive V as a contiguous (1, n_h, hd, N) tensor:

    * ``v_native`` is sampled directly in that layout. Its reference
      transposes it back to (n_h, N, hd).
    * ``v_transposed`` is sampled as (n_h, N, hd) (``v_orig``) and transposed
      for the kernel. Its reference uses ``v_orig`` directly.

    The two V tensors hold independent random data, so the two kernel outputs
    are not expected to match each other; each must match its own reference.
    Per-head cosine similarities are printed, then asserted to be >= ``min_cos``.

    Raises:
        unittest.SkipTest: if CUDA is unavailable (pytest reports it as a skip).
        AssertionError: if any head's kernel output diverges from its reference.
    """
    import unittest  # local: only needed for the no-GPU skip path

    if not torch.cuda.is_available():
        raise unittest.SkipTest("fmha_multihead_decode_raw requires a CUDA device")

    torch.manual_seed(42)
    hd = 64
    n_h = 4
    N = 128
    scale = 1.0 / math.sqrt(hd)
    # bf16 kernel vs. float32 reference: values well below 1 indicate a real
    # math/layout problem rather than rounding noise.
    min_cos = 0.99
    dev = "cuda"

    # Draw order (q, k, v_native, v_orig) is kept so the seed yields the same data.
    q_4d = torch.randn(1, n_h, 1, hd, dtype=torch.bfloat16, device=dev)
    k_4d = torch.randn(1, n_h, N, hd, dtype=torch.bfloat16, device=dev)
    # Case 1: V sampled directly in the kernel layout (hd, N).
    v_native = torch.randn(1, n_h, hd, N, dtype=torch.bfloat16, device=dev)
    # Case 2: V sampled in the standard layout (N, hd), then transposed (and made
    # contiguous) into the kernel layout (1, n_h, hd, N). Independent data from case 1.
    v_orig = torch.randn(n_h, N, hd, dtype=torch.bfloat16, device=dev)
    v_transposed = v_orig.unsqueeze(0).transpose(-1, -2).contiguous()

    # A fresh zero buffer per call: this test does not show whether the kernel
    # writes to it, so sharing one could carry state from call 1 into call 2.
    o_native, _ = fmha_multihead_decode_raw(
        q_4d, k_4d, v_native, scale, 0, 0, False,
        torch.zeros(1, n_h, dtype=torch.float32, device=dev),
    )
    o_transposed, _ = fmha_multihead_decode_raw(
        q_4d, k_4d, v_transposed, scale, 0, 0, False,
        torch.zeros(1, n_h, dtype=torch.float32, device=dev),
    )

    q_ref = q_4d[0]  # (n_h, 1, hd)
    k_ref = k_4d[0]  # (n_h, N, hd)
    # Case 1 reference: undo the (hd, N) layout -> (n_h, N, hd).
    o_ref1 = _reference_decode(q_ref, k_ref, v_native[0].transpose(-1, -2), scale)
    # Case 2 reference: v_orig is already (n_h, N, hd).
    o_ref2 = _reference_decode(q_ref, k_ref, v_orig, scale)

    cos_native = []
    cos_transposed = []
    for h in range(n_h):
        # NOTE(review): indexing assumes kernel output is (1, n_h, 1, hd) like q_4d.
        cos1 = _cosine(o_native[0, h], o_ref1[h])
        cos2 = _cosine(o_transposed[0, h], o_ref2[h])
        # Informational only: different V data, so this need not be close to 1.
        cos_kk = _cosine(o_native[0, h], o_transposed[0, h])
        print(f" Head {h}: native_vs_ref1={cos1:.6f} transposed_vs_ref2={cos2:.6f} native_vs_transposed={cos_kk:.6f}")
        cos_native.append(cos1)
        cos_transposed.append(cos2)

    # Assert only after printing every head, so a failure still shows the full picture.
    assert min(cos_native) >= min_cos, (
        f"native V: per-head cos vs reference {cos_native} below {min_cos}"
    )
    assert min(cos_transposed) >= min_cos, (
        f"transposed V: per-head cos vs reference {cos_transposed} below {min_cos}"
    )
if __name__ == "__main__":
    # Run as a plain script to see the per-head cosine printout directly.
    test_v_layout_comparison()