diff --git a/tests/unit/test_d1_hd512.py b/tests/unit/test_d1_hd512.py new file mode 100644 index 00000000..afb5e83a --- /dev/null +++ b/tests/unit/test_d1_hd512.py @@ -0,0 +1,48 @@ +"""D1 test: HEAD_DIM=512 (DSV4 production config).""" +import torch, math +import cutlass.cute as cute +import cutlass.torch as ct +import cuda.bindings.driver as cuda +from dsv4.kernels.attention.fmha import FmhaKernel + +def test(): + torch.manual_seed(42) + hd, n = 512, 128 + m = 128 + q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda') + k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device='cuda') + v = torch.randn(n, hd, dtype=torch.bfloat16, device='cuda') + v_kernel = v.unsqueeze(-1) + c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda') + + qf = q[:, :, 0].float() + kf = k[:, :, 0].float() + scale = 1.0 / math.sqrt(hd) + attn = qf @ kf.T * scale + attn = torch.softmax(attn, dim=-1) + ref = attn @ v.float() + + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) + mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) + mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + kernel = FmhaKernel(head_dim=hd, s_k=n) + print(f'hd={hd}, n={n}: Compiling...', flush=True) + compiled = cute.compile(kernel, mQ, mK, mV, mC, stream) + compiled(mQ, mK, mV, mC, stream) + torch.cuda.synchronize() + + out = c[:, :, 0].float() + cos = torch.nn.functional.cosine_similarity( + out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0) + ).item() + max_abs = (out - ref).abs().max().item() + print(f'hd={hd}, n={n}: cos {cos:.6f} max_abs {max_abs:.4f} {"PASS" if cos >= 0.97 else "FAIL"}') + if cos < 0.97: + print(f' out[0,:4]={out[0,:4].tolist()}') + print(f' ref[0,:4]={ref[0,:4].tolist()}') + +if __name__ == '__main__': + test()