"""Quick D1 regression test: HEAD_DIM=64 only, must match Stage C.""" import torch, math import cutlass.cute as cute import cutlass.torch as ct import cuda.bindings.driver as cuda from dsv4.kernels.attention.fmha import FmhaKernel def test(): torch.manual_seed(42) hd, n = 64, 128 m = 128 q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda') k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device='cuda') v = torch.randn(n, hd, dtype=torch.bfloat16, device='cuda') v_kernel = v.unsqueeze(-1) c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda') qf = q[:, :, 0].float() kf = k[:, :, 0].float() scale = 1.0 / math.sqrt(hd) attn = qf @ kf.T * scale attn = torch.softmax(attn, dim=-1) ref = attn @ v.float() mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) kernel = FmhaKernel(head_dim=hd, s_k=n) print(f'hd={hd}, n={n}: Compiling...', flush=True) compiled = cute.compile(kernel, mQ, mK, mV, mC, stream) compiled(mQ, mK, mV, mC, stream) torch.cuda.synchronize() out = c[:, :, 0].float() cos = torch.nn.functional.cosine_similarity( out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0) ).item() max_abs = (out - ref).abs().max().item() print(f'hd={hd}, n={n}: cos {cos:.6f} max_abs {max_abs:.4f} {"PASS" if cos >= 0.97 else "FAIL"}') if __name__ == '__main__': test()