"""D2: n_h=1 regression test.

Identical to D1 but with the D2 test infrastructure. Runs ``FmhaKernel``
(``use_smem_p=False``, ``normalize=False``) once and compares its output with
an FP32 PyTorch reference of the *unnormalized* attention output
``exp(S - rowmax(S)) @ V`` where ``S = Q @ K^T / sqrt(hd)``.

Only the first ``kernel.pv_n_tile`` columns of V are passed to the kernel, so
only those output columns are compared against the reference.
"""
import math

import torch

import cuda.bindings.driver as cuda
import cutlass.cute as cute
import cutlass.torch as ct

from dsv4.kernels.attention.fmha import FmhaKernel

# Minimum cosine similarity (kernel output vs. FP32 reference) required to PASS.
COS_THRESHOLD = 0.99


def _to_cute(t):
    """Wrap torch tensor ``t`` via DLPack with its layout marked dynamic on its leading dim."""
    return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))


def test():
    """Compile and run the kernel once, then check it against the FP32 reference.

    Prints a ``PASS``/``FAIL`` line. Raises ``AssertionError`` on FAIL, so a
    test runner that collects this function (e.g. pytest) reports the failure
    instead of passing silently.
    """
    hd = 64
    s_k = 128
    m = 128
    torch.manual_seed(42)
    # The trailing size-1 axis on q/k (and on the output/LSE buffers below) is
    # presumably the head axis, i.e. n_h=1 per the module title -- confirm
    # against FmhaKernel's expected layouts.
    q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda')
    # Passed to the kernel as its LSE argument; its contents are not checked here.
    lse_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')

    kernel = FmhaKernel(head_dim=hd, s_k=s_k, use_smem_p=False, normalize=False)
    pv_n_tile = kernel.pv_n_tile
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    # The kernel only receives the first pv_n_tile columns of V and writes a
    # (m, pv_n_tile) output, so the reference must use the same columns;
    # comparing against the full (m, hd) reference breaks whenever
    # pv_n_tile != hd.
    v_tile = v[:, :pv_n_tile].contiguous()
    v_kernel = v_tile.unsqueeze(-1)
    c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')

    # FP32 reference: unnormalized softmax numerator (row-max subtracted for
    # stability) times V, matching normalize=False.
    qf = q[:, :, 0].float()
    kf = k[:, :, 0].float()
    scale = 1.0 / math.sqrt(hd)
    scores = qf @ kf.T * scale
    attn_exp = torch.exp(scores - scores.amax(dim=-1, keepdim=True))
    ref_unnorm = attn_exp @ v_tile.float()

    mQ = _to_cute(q)
    mK = _to_cute(k)
    mV = _to_cute(v_kernel)
    mC = _to_cute(c_tile)
    mLSE = _to_cute(lse_tensor)

    print('Compiling...', flush=True)
    compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE)
    compiled(mQ, mK, mV, mC, stream, mLSE)
    torch.cuda.synchronize()

    out = c_tile[:, :, 0].float()
    # Single cosine over the flattened tensors (not per row).
    cos = torch.nn.functional.cosine_similarity(
        out.flatten().unsqueeze(0),
        ref_unnorm.flatten().unsqueeze(0),
    ).item()
    # NaN compares False, so a NaN result is reported as FAIL.
    passed = cos >= COS_THRESHOLD
    print(f'hd={hd}, s_k={s_k}: cos_unnorm {cos:.6f} {"PASS" if passed else "FAIL"}')
    assert passed, f'hd={hd}, s_k={s_k}: cosine similarity {cos:.6f} < {COS_THRESHOLD}'


if __name__ == '__main__':
    test()