debug: V layout comparison test

This commit is contained in:
2026-05-30 08:22:49 +00:00
parent 074c4c4f42
commit 78e6d58b85

View File

@@ -0,0 +1,103 @@
"""
Debug test: call fmha_multihead_decode_raw directly with production-style V.
Isolates whether the issue is in the V transpose or the production.py plumbing.
"""
import torch
import math
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from dsv4.kernels.attention.fmha_multihead_op import fmha_multihead_decode_raw
def cosine_sim(a, b):
a = a.flatten().float()
b = b.flatten().float()
return (a @ b) / (a.norm() * b.norm() + 1e-30)
def test_production_v_layout():
"""Test with V created as (N, hd) then transposed (production path)."""
torch.manual_seed(42)
hd = 64
n_h = 4
N = 128
scale = 1.0 / math.sqrt(hd)
# Create Q, K in the same way as both the working test and production
q_4d = torch.randn(1, n_h, 1, hd, dtype=torch.bfloat16, device='cuda').contiguous()
k_4d = torch.randn(1, n_h, N, hd, dtype=torch.bfloat16, device='cuda').contiguous()
# V: production path creates (n_kv, N, hd) then transposes to (1, n_kv, hd, N)
v_orig = torch.randn(n_h, N, hd, dtype=torch.bfloat16, device='cuda')
v_4d = v_orig.unsqueeze(0).transpose(-1, -2).contiguous()
print(f"V orig shape: {v_orig.shape}")
print(f"V 4d shape: {v_4d.shape}, strides: {v_4d.stride()}")
sb = torch.zeros(1, n_h, dtype=torch.float32, device='cuda')
o_4d, lse_4d = fmha_multihead_decode_raw(q_4d, k_4d, v_4d, scale, 0, 0, False, sb)
# Reference: use v_orig (N, hd) per head
q_ref = q_4d[0] # (n_h, 1, hd)
k_ref = k_4d[0] # (n_h, N, hd)
for h in range(n_h):
q_h = q_ref[h] # (1, hd)
k_h = k_ref[h] # (N, hd)
v_h = v_orig[h] # (N, hd)
s = torch.matmul(q_h.float(), k_h.float().T) * scale
s = torch.softmax(s, dim=-1)
o_ref = torch.matmul(s, v_h.float())
cos = torch.nn.functional.cosine_similarity(
o_4d[0, h].float().flatten().unsqueeze(0),
o_ref.flatten().unsqueeze(0),
).item()
print(f" Head {h}: cos={cos:.6f}")
def test_native_v_layout():
"""Test with V created as (hd, N) natively (working test style)."""
torch.manual_seed(42)
hd = 64
n_h = 4
N = 128
scale = 1.0 / math.sqrt(hd)
q_4d = torch.randn(1, n_h, 1, hd, dtype=torch.bfloat16, device='cuda').contiguous()
k_4d = torch.randn(1, n_h, N, hd, dtype=torch.bfloat16, device='cuda').contiguous()
v_4d = torch.randn(1, n_h, hd, N, dtype=torch.bfloat16, device='cuda').contiguous()
sb = torch.zeros(1, n_h, dtype=torch.float32, device='cuda')
o_4d, lse_4d = fmha_multihead_decode_raw(q_4d, k_4d, v_4d, scale, 0, 0, False, sb)
# Reference: V is (hd, N) per head, transpose to (N, hd) for reference
v_ref = v_4d[0].transpose(-1, -2) # (n_h, N, hd)
q_ref = q_4d[0]
k_ref = k_4d[0]
for h in range(n_h):
q_h = q_ref[h]
k_h = k_ref[h]
v_h = v_ref[h] # (N, hd)
s = torch.matmul(q_h.float(), k_h.float().T) * scale
s = torch.softmax(s, dim=-1)
o_ref = torch.matmul(s, v_h.float())
cos = torch.nn.functional.cosine_similarity(
o_4d[0, h].float().flatten().unsqueeze(0),
o_ref.flatten().unsqueeze(0),
).item()
print(f" Head {h}: cos={cos:.6f}")
if __name__ == "__main__":
print("=== Test 1: V created as (N,hd) then transposed (production path) ===")
test_production_v_layout()
print()
print("=== Test 2: V created natively as (hd,N) (working test style) ===")
test_native_v_layout()