From db353ec35ab54bb18f60cc164cfc37ad063175df Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 22:39:25 +0000 Subject: [PATCH] D2: add simple n_h=1 regression test --- tests/unit/test_d2_regression.py | 60 ++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 tests/unit/test_d2_regression.py diff --git a/tests/unit/test_d2_regression.py b/tests/unit/test_d2_regression.py new file mode 100644 index 00000000..acd67453 --- /dev/null +++ b/tests/unit/test_d2_regression.py @@ -0,0 +1,60 @@ +""" +D2: n_h=1 regression test. Identical to D1 but with the D2 test infrastructure. +""" +import torch, math +import cutlass.cute as cute +import cutlass.torch as ct +import cuda.bindings.driver as cuda +from dsv4.kernels.attention.fmha import FmhaKernel + + +def test(): + hd = 64 + s_k = 128 + m = 128 + torch.manual_seed(42) + + q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda') + k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda') + v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda') + c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda') + + # FP32 reference + qf = q[:, :, 0].float() + kf = k[:, :, 0].float() + scale = 1.0 / math.sqrt(hd) + attn_max = (qf @ kf.T * scale).max(dim=-1, keepdim=True)[0] + attn_exp = torch.exp(qf @ kf.T * scale - attn_max) + attn_sum = attn_exp.sum(dim=-1, keepdim=True) + ref_norm = (attn_exp / attn_sum) @ v.float() + + lse_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda') + + kernel = FmhaKernel(head_dim=hd, s_k=s_k, use_smem_p=False, normalize=False) + pv_n_tile = kernel.pv_n_tile + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + v_tile = v[:, 0:pv_n_tile].contiguous() + v_kernel = v_tile.unsqueeze(-1) + c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda') + + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) + mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) + mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile)) + mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor)) + + print(f'Compiling...', flush=True) + compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE) + compiled(mQ, mK, mV, mC, stream, mLSE) + torch.cuda.synchronize() + + out = c_tile[:, :, 0].float() + cos = torch.nn.functional.cosine_similarity( + out.flatten().unsqueeze(0), ref_norm.flatten().unsqueeze(0) + ).item() + print(f'hd={hd}, s_k={s_k}: cos_norm {cos:.6f} {"PASS" if cos >= 0.99 else "FAIL"}') + + +if __name__ == '__main__': + test()