D2: add simple n_h=1 regression test

This commit is contained in:
2026-05-24 22:39:25 +00:00
parent 4418e04a28
commit db353ec35a

View File

@@ -0,0 +1,60 @@
"""
D2: n_h=1 regression test. Identical to D1 but with the D2 test infrastructure.
"""
import torch, math
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
def test():
hd = 64
s_k = 128
m = 128
torch.manual_seed(42)
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
# FP32 reference
qf = q[:, :, 0].float()
kf = k[:, :, 0].float()
scale = 1.0 / math.sqrt(hd)
attn_max = (qf @ kf.T * scale).max(dim=-1, keepdim=True)[0]
attn_exp = torch.exp(qf @ kf.T * scale - attn_max)
attn_sum = attn_exp.sum(dim=-1, keepdim=True)
ref_norm = (attn_exp / attn_sum) @ v.float()
lse_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
kernel = FmhaKernel(head_dim=hd, s_k=s_k, use_smem_p=False, normalize=False)
pv_n_tile = kernel.pv_n_tile
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
v_tile = v[:, 0:pv_n_tile].contiguous()
v_kernel = v_tile.unsqueeze(-1)
c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor))
print(f'Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE)
compiled(mQ, mK, mV, mC, stream, mLSE)
torch.cuda.synchronize()
out = c_tile[:, :, 0].float()
cos = torch.nn.functional.cosine_similarity(
out.flatten().unsqueeze(0), ref_norm.flatten().unsqueeze(0)
).item()
print(f'hd={hd}, s_k={s_k}: cos_norm {cos:.6f} {"PASS" if cos >= 0.99 else "FAIL"}')
if __name__ == '__main__':
test()