From aa66f44ff9545f1dd5ffc7fd8c7b3f2335e90d04 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 25 May 2026 17:00:56 +0000 Subject: [PATCH] add n_h=1 regression test --- tests/unit/test_d2_regression.py | 80 ++++++++++++++------------------ 1 file changed, 34 insertions(+), 46 deletions(-) diff --git a/tests/unit/test_d2_regression.py b/tests/unit/test_d2_regression.py index 09f78706..af9ad33e 100644 --- a/tests/unit/test_d2_regression.py +++ b/tests/unit/test_d2_regression.py @@ -1,60 +1,48 @@ -""" -D2: n_h=1 regression test. Identical to D1 but with the D2 test infrastructure. -""" -import torch, math +"""Quick test: n_h=1 regression after grid changes.""" +import torch +import math import cutlass.cute as cute -import cutlass.torch as ct import cuda.bindings.driver as cuda +import cutlass.torch as ct from dsv4.kernels.attention.fmha import FmhaKernel def test(): - hd = 64 - s_k = 128 - m = 128 + print("=== n_h=1 regression (hd=64, s_k=128) ===") torch.manual_seed(42) - - q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda') - k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda') - v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda') - c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda') - - # FP32 reference - qf = q[:, :, 0].float() - kf = k[:, :, 0].float() + M, s_k, hd = 128, 128, 64 scale = 1.0 / math.sqrt(hd) - attn_max = (qf @ kf.T * scale).max(dim=-1, keepdim=True)[0] - attn_exp = torch.exp(qf @ kf.T * scale - attn_max) - attn_sum = attn_exp.sum(dim=-1, keepdim=True) - ref_unnorm = attn_exp @ v.float() - ref_norm = (attn_exp / attn_sum) @ v.float() - - lse_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda') - - kernel = FmhaKernel(head_dim=hd, s_k=s_k, use_smem_p=False, normalize=False) - pv_n_tile = kernel.pv_n_tile + + q = torch.randn(M, hd, 1, dtype=torch.bfloat16, device='cuda') + k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda') + v = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda') + o = torch.zeros(M, hd, 1, dtype=torch.bfloat16, device='cuda') + + fmha = FmhaKernel(head_dim=hd, s_k=s_k, normalize=True) stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - - v_tile = v[:, 0:pv_n_tile].contiguous() - v_kernel = v_tile.unsqueeze(-1) - c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda') - - mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) - mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) - mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) - mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile)) - mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor)) - - print(f'Compiling...', flush=True) - compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE) - compiled(mQ, mK, mV, mC, stream, mLSE) - torch.cuda.synchronize() - - out = c_tile[:, :, 0].float() + + q_c = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + k_c = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) + v_c = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v)) + o_c = ct.from_dlpack(o).mark_layout_dynamic(leading_dim=ct.get_leading_dim(o)) + fmha(q_c, k_c, v_c, o_c, stream) + + # Reference + scores = torch.matmul(q[:,:,0].float(), k[:,:,0].float().T) * scale + max_s = scores.max(dim=-1, keepdim=True).values + exp_s = (scores - max_s).exp() + sum_s = exp_s.sum(dim=-1, keepdim=True) + p = exp_s / sum_s + ref = torch.matmul(p, v[:,:,0].float()).to(torch.bfloat16) + cos = torch.nn.functional.cosine_similarity( - out.flatten().unsqueeze(0), ref_unnorm.flatten().unsqueeze(0) + o[:,:,0].flatten().float().unsqueeze(0), ref.flatten().float().unsqueeze(0) ).item() - print(f'hd={hd}, s_k={s_k}: cos_norm {cos:.6f} {"PASS" if cos >= 0.99 else "FAIL"}') + print(f" cos = {cos:.6f}") + if cos >= 0.99: + print(" ✅ PASS") + else: + print(f" ❌ FAIL") if __name__ == '__main__':