debug: V layout comparison test
This commit is contained in:
103
tests/unit/test_p3_v_debug.py
Normal file
103
tests/unit/test_p3_v_debug.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
Debug test: call fmha_multihead_decode_raw directly with production-style V.
|
||||
Isolates whether the issue is in the V transpose or the production.py plumbing.
|
||||
"""
|
||||
import torch
|
||||
import math
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from dsv4.kernels.attention.fmha_multihead_op import fmha_multihead_decode_raw
|
||||
|
||||
|
||||
def cosine_sim(a, b):
|
||||
a = a.flatten().float()
|
||||
b = b.flatten().float()
|
||||
return (a @ b) / (a.norm() * b.norm() + 1e-30)
|
||||
|
||||
|
||||
def test_production_v_layout():
|
||||
"""Test with V created as (N, hd) then transposed (production path)."""
|
||||
torch.manual_seed(42)
|
||||
|
||||
hd = 64
|
||||
n_h = 4
|
||||
N = 128
|
||||
scale = 1.0 / math.sqrt(hd)
|
||||
|
||||
# Create Q, K in the same way as both the working test and production
|
||||
q_4d = torch.randn(1, n_h, 1, hd, dtype=torch.bfloat16, device='cuda').contiguous()
|
||||
k_4d = torch.randn(1, n_h, N, hd, dtype=torch.bfloat16, device='cuda').contiguous()
|
||||
|
||||
# V: production path creates (n_kv, N, hd) then transposes to (1, n_kv, hd, N)
|
||||
v_orig = torch.randn(n_h, N, hd, dtype=torch.bfloat16, device='cuda')
|
||||
v_4d = v_orig.unsqueeze(0).transpose(-1, -2).contiguous()
|
||||
print(f"V orig shape: {v_orig.shape}")
|
||||
print(f"V 4d shape: {v_4d.shape}, strides: {v_4d.stride()}")
|
||||
|
||||
sb = torch.zeros(1, n_h, dtype=torch.float32, device='cuda')
|
||||
o_4d, lse_4d = fmha_multihead_decode_raw(q_4d, k_4d, v_4d, scale, 0, 0, False, sb)
|
||||
|
||||
# Reference: use v_orig (N, hd) per head
|
||||
q_ref = q_4d[0] # (n_h, 1, hd)
|
||||
k_ref = k_4d[0] # (n_h, N, hd)
|
||||
|
||||
for h in range(n_h):
|
||||
q_h = q_ref[h] # (1, hd)
|
||||
k_h = k_ref[h] # (N, hd)
|
||||
v_h = v_orig[h] # (N, hd)
|
||||
s = torch.matmul(q_h.float(), k_h.float().T) * scale
|
||||
s = torch.softmax(s, dim=-1)
|
||||
o_ref = torch.matmul(s, v_h.float())
|
||||
|
||||
cos = torch.nn.functional.cosine_similarity(
|
||||
o_4d[0, h].float().flatten().unsqueeze(0),
|
||||
o_ref.flatten().unsqueeze(0),
|
||||
).item()
|
||||
print(f" Head {h}: cos={cos:.6f}")
|
||||
|
||||
|
||||
def test_native_v_layout():
|
||||
"""Test with V created as (hd, N) natively (working test style)."""
|
||||
torch.manual_seed(42)
|
||||
|
||||
hd = 64
|
||||
n_h = 4
|
||||
N = 128
|
||||
scale = 1.0 / math.sqrt(hd)
|
||||
|
||||
q_4d = torch.randn(1, n_h, 1, hd, dtype=torch.bfloat16, device='cuda').contiguous()
|
||||
k_4d = torch.randn(1, n_h, N, hd, dtype=torch.bfloat16, device='cuda').contiguous()
|
||||
v_4d = torch.randn(1, n_h, hd, N, dtype=torch.bfloat16, device='cuda').contiguous()
|
||||
|
||||
sb = torch.zeros(1, n_h, dtype=torch.float32, device='cuda')
|
||||
o_4d, lse_4d = fmha_multihead_decode_raw(q_4d, k_4d, v_4d, scale, 0, 0, False, sb)
|
||||
|
||||
# Reference: V is (hd, N) per head, transpose to (N, hd) for reference
|
||||
v_ref = v_4d[0].transpose(-1, -2) # (n_h, N, hd)
|
||||
q_ref = q_4d[0]
|
||||
k_ref = k_4d[0]
|
||||
|
||||
for h in range(n_h):
|
||||
q_h = q_ref[h]
|
||||
k_h = k_ref[h]
|
||||
v_h = v_ref[h] # (N, hd)
|
||||
s = torch.matmul(q_h.float(), k_h.float().T) * scale
|
||||
s = torch.softmax(s, dim=-1)
|
||||
o_ref = torch.matmul(s, v_h.float())
|
||||
|
||||
cos = torch.nn.functional.cosine_similarity(
|
||||
o_4d[0, h].float().flatten().unsqueeze(0),
|
||||
o_ref.flatten().unsqueeze(0),
|
||||
).item()
|
||||
print(f" Head {h}: cos={cos:.6f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=== Test 1: V created as (N,hd) then transposed (production path) ===")
|
||||
test_production_v_layout()
|
||||
print()
|
||||
print("=== Test 2: V created natively as (hd,N) (working test style) ===")
|
||||
test_native_v_layout()
|
||||
Reference in New Issue
Block a user