104 lines
3.5 KiB
Python
104 lines
3.5 KiB
Python
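"""
Debug test: call fmha_multihead_decode_raw directly with a production-style V.

Isolates whether a mismatch comes from the V transpose or from the
production.py plumbing. It does this by comparing the kernel output against
an eager float32 reference for two V constructions: (N, hd) then transposed,
and (hd, N) natively.
"""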
|
|
import torch
|
|
import math
|
|
import sys
|
|
import os
|
|
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
from dsv4.kernels.attention.fmha_multihead_op import fmha_multihead_decode_raw
|
|
|
|
|
|
def cosine_sim(a, b):
    """Cosine similarity of two tensors, treated as flat float32 vectors.

    Both inputs are flattened and cast to float32 before taking the dot
    product. The 1e-30 added to the denominator keeps an all-zero input from
    dividing by zero; the result is then 0. Returns a 0-d tensor.
    """
    x = a.flatten().float()
    y = b.flatten().float()
    denom = x.norm() * y.norm() + 1e-30
    return torch.dot(x, y) / denom
|
|
|
|
|
|
def test_production_v_layout(contiguous=True, min_cos=None):
    """Test with V created as (N, hd) then transposed (production path).

    V is sampled as (n_h, N, hd), given a leading batch dim and transposed to
    (1, n_h, hd, N) before being passed to ``fmha_multihead_decode_raw``.
    Each head's kernel output is compared by cosine similarity against an
    eager float32 reference, softmax(q k^T * scale) v, computed with the
    untransposed V. One line is printed per head.

    NOTE(review): with ``contiguous=True`` (the default, matching the
    original behaviour), ``.contiguous()`` materialises the transpose. The
    kernel then receives a tensor with exactly the same shape and strides as
    the natively created (1, n_h, hd, N) V in ``test_native_v_layout``; only
    the values differ. Pass ``contiguous=False`` to hand the kernel the
    strided transposed view instead. That is the case which exercises
    stride handling if the production path does not call ``.contiguous()``
    (TODO: confirm against production.py).

    Args:
        contiguous: if True (default), call ``.contiguous()`` on the
            transposed V; if False, pass the non-contiguous view.
        min_cos: if not None, assert after printing that every head's cosine
            similarity is >= this value (NaN counts as a failure). The
            default None only prints, so the test cannot fail on accuracy.
    """
    torch.manual_seed(42)

    hd = 64
    n_h = 4
    N = 128
    scale = 1.0 / math.sqrt(hd)

    # Create Q, K in the same way as both the working test and production
    q_4d = torch.randn(1, n_h, 1, hd, dtype=torch.bfloat16, device='cuda').contiguous()
    k_4d = torch.randn(1, n_h, N, hd, dtype=torch.bfloat16, device='cuda').contiguous()

    # V: production path creates (n_kv, N, hd) then transposes to (1, n_kv, hd, N).
    # Here n_kv == n_h, matching K above.
    v_orig = torch.randn(n_h, N, hd, dtype=torch.bfloat16, device='cuda')
    v_4d = v_orig.unsqueeze(0).transpose(-1, -2)
    if contiguous:
        # Copies the data into a fresh (1, n_h, hd, N) row-major buffer.
        v_4d = v_4d.contiguous()
    print(f"V orig shape: {v_orig.shape}")
    print(f"V 4d shape: {v_4d.shape}, strides: {v_4d.stride()}, contiguous: {v_4d.is_contiguous()}")

    # fp32 (1, n_h) zeros passed as the kernel's last argument; its meaning
    # is defined by fmha_multihead_decode_raw (TODO confirm).
    sb = torch.zeros(1, n_h, dtype=torch.float32, device='cuda')
    o_4d, _lse_4d = fmha_multihead_decode_raw(q_4d, k_4d, v_4d, scale, 0, 0, False, sb)

    # Reference: plain per-head softmax attention in float32 using v_orig
    # (N, hd), so the reference never depends on the transpose under test.
    cos_per_head = []
    for h in range(n_h):
        q_h = q_4d[0, h]  # (1, hd)
        k_h = k_4d[0, h]  # (N, hd)
        v_h = v_orig[h]  # (N, hd)
        s = torch.matmul(q_h.float(), k_h.float().T) * scale
        s = torch.softmax(s, dim=-1)
        o_ref = torch.matmul(s, v_h.float())

        cos = torch.nn.functional.cosine_similarity(
            o_4d[0, h].float().flatten().unsqueeze(0),
            o_ref.flatten().unsqueeze(0),
        ).item()
        cos_per_head.append(cos)
        print(f" Head {h}: cos={cos:.6f}")

    if min_cos is not None:
        # Checked only after every head is printed, so a failure still shows
        # the full per-head picture. `not c >= min_cos` also flags NaN.
        bad_heads = [h for h, c in enumerate(cos_per_head) if not c >= min_cos]
        assert not bad_heads, (
            f"heads {bad_heads} below min_cos={min_cos}: {cos_per_head}"
        )
|
|
|
|
|
|
def test_native_v_layout():
    """Run the decode kernel with V sampled directly in (hd, N) layout.

    This is the working-test configuration: V is drawn as (1, n_h, hd, N)
    and passed to ``fmha_multihead_decode_raw`` without any transpose. Each
    head's kernel output is compared by cosine similarity against an eager
    float32 softmax(q k^T * scale) v. That reference uses the head's V
    transposed back to (N, hd). Results are only printed; nothing is
    asserted.
    """
    torch.manual_seed(42)

    head_dim, num_heads, seq_len = 64, 4, 128
    softmax_scale = 1.0 / math.sqrt(head_dim)

    def _rand(*shape):
        # Shared dtype/device for every input. Calls below run in Q, K, V
        # order, so the seeded RNG is consumed in that order.
        return torch.randn(*shape, dtype=torch.bfloat16, device='cuda').contiguous()

    q_4d = _rand(1, num_heads, 1, head_dim)
    k_4d = _rand(1, num_heads, seq_len, head_dim)
    v_4d = _rand(1, num_heads, head_dim, seq_len)

    # fp32 (1, n_h) zeros for the kernel's final argument; its meaning is
    # defined by fmha_multihead_decode_raw (TODO confirm).
    sb = torch.zeros(1, num_heads, dtype=torch.float32, device='cuda')
    out, _lse = fmha_multihead_decode_raw(q_4d, k_4d, v_4d, softmax_scale, 0, 0, False, sb)

    # The reference wants (N, hd) per head, so undo the (hd, N) layout.
    v_heads = v_4d[0].transpose(-1, -2)

    for h, (q_h, k_h, v_h) in enumerate(zip(q_4d[0], k_4d[0], v_heads)):
        logits = torch.matmul(q_h.float(), k_h.float().T) * softmax_scale
        expected = torch.matmul(torch.softmax(logits, dim=-1), v_h.float())
        got = out[0, h].float().flatten().unsqueeze(0)
        cos = torch.nn.functional.cosine_similarity(
            got, expected.flatten().unsqueeze(0)
        ).item()
        print(f" Head {h}: cos={cos:.6f}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run both layouts back to back, separated by a blank line, so their
    # per-head cosine printouts can be compared directly.
    _cases = (
        ("=== Test 1: V created as (N,hd) then transposed (production path) ===",
         test_production_v_layout),
        ("=== Test 2: V created natively as (hd,N) (working test style) ===",
         test_native_v_layout),
    )
    for _idx, (_banner, _run) in enumerate(_cases):
        if _idx:
            print()
        print(_banner)
        _run()
|