50 lines
1.7 KiB
Python
50 lines
1.7 KiB
Python
"""Quick test: n_h=1 regression after grid changes."""
|
|
import torch
|
|
import math
|
|
import cutlass.cute as cute
|
|
import cuda.bindings.driver as cuda
|
|
import cutlass.torch as ct
|
|
from dsv4.kernels.attention.fmha import FmhaKernel
|
|
|
|
|
|
def _reference_attention(q, k, v, scale):
    """Return ``softmax(q @ k.T * scale) @ v`` computed in fp32, cast to bf16.

    Args:
        q: ``(M, hd)`` query matrix.
        k: ``(s_k, hd)`` key matrix.
        v: ``(s_k, hd)`` value matrix.
        scale: multiplier applied to the raw ``q @ k.T`` scores.

    Returns:
        ``(M, hd)`` attention output rounded to bfloat16, so it is compared
        at the same storage precision as the kernel output.
    """
    scores = torch.matmul(q.float(), k.float().T) * scale
    # Numerically stable equivalent of exp(s - max) / sum(exp(s - max)).
    p = torch.softmax(scores, dim=-1)
    return torch.matmul(p, v.float()).to(torch.bfloat16)


def test():
    """Regression check for FmhaKernel with a single head (n_h=1).

    Runs the kernel on seeded random bf16 inputs (M=128 queries, s_k=128
    keys, head_dim=64). It then compares the kernel output to an fp32 PyTorch
    reference using cosine similarity over the whole output.

    Raises:
        AssertionError: if the cosine similarity is below 0.99. This includes
            a NaN similarity, e.g. when the kernel writes non-finite values,
            because NaN never compares >= the threshold.
    """
    print("=== n_h=1 regression (hd=64, s_k=128) ===")
    torch.manual_seed(42)

    M, s_k, hd = 128, 128, 64
    scale = 1.0 / math.sqrt(hd)
    threshold = 0.99

    # Every tensor carries a trailing axis of size 1. Presumably this is the
    # head axis (n_h=1); TODO confirm against FmhaKernel's expected layout.
    q = torch.randn(M, hd, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
    o = torch.zeros(M, hd, 1, dtype=torch.bfloat16, device='cuda')

    fmha = FmhaKernel(head_dim=hd, s_k=s_k, normalize=True)
    # The kernel is launched on PyTorch's current stream. That makes the
    # torch reads of `o` below stream-ordered after it, and `.item()` then
    # synchronizes the host.
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    q_c, k_c, v_c, o_c = [
        ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))
        for t in (q, k, v, o)
    ]
    fmha(q_c, k_c, v_c, o_c, stream)

    ref = _reference_attention(q[:, :, 0], k[:, :, 0], v[:, :, 0], scale)

    cos = torch.nn.functional.cosine_similarity(
        o[:, :, 0].flatten().float().unsqueeze(0),
        ref.flatten().float().unsqueeze(0),
    ).item()
    print(f" cos = {cos:.6f}")

    passed = cos >= threshold
    print(" ✅ PASS" if passed else " ❌ FAIL")
    # Fail loudly so pytest/CI and the script's exit status register a
    # regression, instead of only printing it.
    assert passed, f"cosine similarity {cos:.6f} below threshold {threshold}"
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: run the regression check directly, outside pytest.
    test()
|