From 673825c24274c20e69a3b43fcc722cfb044c08b5 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 25 May 2026 17:11:59 +0000 Subject: [PATCH] rewrite D2 regression test: match existing Stage D1 test pattern with cute.compile + PV tiles --- tests/unit/test_d2_regression.py | 186 +++++++++++++++++++++++++------ 1 file changed, 149 insertions(+), 37 deletions(-) diff --git a/tests/unit/test_d2_regression.py b/tests/unit/test_d2_regression.py index 938ee5dc..3ef9fb1e 100644 --- a/tests/unit/test_d2_regression.py +++ b/tests/unit/test_d2_regression.py @@ -1,54 +1,166 @@ -"""Quick test: n_h=1 regression after grid changes.""" +""" +FMHA D2 regression test (matches existing test pattern). + +Uses the same cute.compile + PV tile iteration as test_fmha_v3_stage_d1.py. + +Run: ~/.openclaw/workspace/fire_b200_test tests/unit/test_d2_regression.py +""" import torch import math +import cutlass import cutlass.cute as cute -import cuda.bindings.driver as cuda import cutlass.torch as ct +from cutlass import Float32, BFloat16 +import cuda.bindings.driver as cuda + from dsv4.kernels.attention.fmha import FmhaKernel -def test(): - print("=== n_h=1 regression (hd=64, s_k=128) ===") - torch.manual_seed(42) - M, s_k, hd = 128, 128, 64 - scale = 1.0 / math.sqrt(hd) - - q = torch.randn(M, hd, 1, dtype=torch.bfloat16, device='cuda') - k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda') - v = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda') - o = torch.zeros(M, hd, 1, dtype=torch.bfloat16, device='cuda') - - fmha = FmhaKernel(head_dim=hd, s_k=s_k, normalize=False) - stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - - q_c = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) - k_c = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) - v_c = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v)) - o_c = ct.from_dlpack(o).mark_layout_dynamic(leading_dim=ct.get_leading_dim(o)) - lse = torch.zeros(M, dtype=torch.float32, device='cuda') - lse_c = ct.from_dlpack(lse).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse)) - fmha(q_c, k_c, v_c, o_c, stream, lse_c) - - # External normalization using LSE - row_sum = lse.exp() - o_norm = o[:,:,0] / row_sum.unsqueeze(-1) - - # Reference - scores = torch.matmul(q[:,:,0].float(), k[:,:,0].float().T) * scale +def reference_fmha(q, k, v, scale): + """FP32 reference: q (M, hd), k (s_k, hd), v (s_k, hd) → o (M, hd)""" + scores = torch.matmul(q.float(), k.float().T) * scale max_s = scores.max(dim=-1, keepdim=True).values exp_s = (scores - max_s).exp() sum_s = exp_s.sum(dim=-1, keepdim=True) p = exp_s / sum_s - ref = torch.matmul(p, v[:,:,0].float()).to(torch.bfloat16) + o = torch.matmul(p, v.float()) + return o.to(torch.bfloat16), (sum_s.log() + max_s) + + +def test_d2_regression(): + """Regression test matching existing Stage D1 pattern.""" + print("\n=== Regression test (hd=64, s_k=128) ===") + torch.manual_seed(42) + m = 128; n_kv = 128; hd = 64 + scale = 1.0 / math.sqrt(hd) + + q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda') + k = torch.randn(n_kv, hd, 1, dtype=torch.bfloat16, device='cuda') + v = torch.randn(n_kv, hd, dtype=torch.bfloat16, device='cuda') + + kernel = FmhaKernel(head_dim=hd, s_k=n_kv, use_smem_p=False) + pv_n_tile = kernel.pv_n_tile + n_pv_tiles = kernel.n_pv_tiles + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + # Compile with first PV tile + v_tile = v[:, 0:pv_n_tile].contiguous() + v_kernel = v_tile.unsqueeze(-1) + c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda') + lse_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda') + + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) + mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) + mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile)) + mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor)) + + compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE) + + # Run PV tiles + o_unnorm = torch.zeros(m, hd, dtype=torch.float32, device='cuda') + for pv in range(n_pv_tiles): + v_tile = v[:, pv*pv_n_tile:(pv+1)*pv_n_tile].contiguous() + v_kernel = v_tile.unsqueeze(-1) + c_tile.zero_() + lse_tensor.zero_() + + mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) + mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile)) + mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor)) + + compiled(mQ, mK, mV, mC, stream, mLSE) + + o_unnorm[:, pv*pv_n_tile:(pv+1)*pv_n_tile] = c_tile[:,:,0].float() + + # External normalization using LSE + lse = lse_tensor[:,0,0] # (m,) + row_sum = lse.exp() + o_norm = o_unnorm / row_sum.unsqueeze(-1) + o_bf16 = o_norm.to(torch.bfloat16) + + # Reference + ref, _ = reference_fmha(q[:,:,0], k[:,:,0], v, scale) cos = torch.nn.functional.cosine_similarity( - o_norm.flatten().float().unsqueeze(0), ref.flatten().float().unsqueeze(0) + o_bf16.flatten().float().unsqueeze(0), ref.flatten().float().unsqueeze(0) ).item() - print(f" cos (ext norm) = {cos:.6f}") - if cos >= 0.99: - print(" ✅ PASS") - else: - print(f" ❌ FAIL") + print(f" cos = {cos:.6f}") + assert cos >= 0.99, f"cosine too low: {cos}" + print(" ✅ PASS") + + +def test_d2_headpacked_128(): + """n_h=128, T=1 (Pro decode): M=128, heads packed into M.""" + print("\n=== n_h=128, T=1 (Pro decode, hd=64) ===") + torch.manual_seed(42) + n_h, T, s_k, hd = 128, 1, 128, 64 + scale = 1.0 / math.sqrt(hd) + + # Per-head Q + q_heads = torch.randn(n_h, T, hd, dtype=torch.bfloat16, device='cuda') + # Pack heads into M: (n_h*T, hd) → (128, 64, 1) + q = q_heads.reshape(n_h * T, hd).unsqueeze(-1) + k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda') + v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda') + + kernel = FmhaKernel(head_dim=hd, s_k=s_k, use_smem_p=False) + pv_n_tile = kernel.pv_n_tile + n_pv_tiles = kernel.n_pv_tiles + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + v_tile = v[:, 0:pv_n_tile].contiguous().unsqueeze(-1) + c_tile = torch.zeros(n_h * T, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda') + lse_tensor = torch.zeros(n_h * T, 1, 1, dtype=torch.float32, device='cuda') + + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) + mV = ct.from_dlpack(v_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_tile)) + mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile)) + mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor)) + + compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE) + + o_unnorm = torch.zeros(n_h * T, hd, dtype=torch.float32, device='cuda') + for pv in range(n_pv_tiles): + v_tile = v[:, pv*pv_n_tile:(pv+1)*pv_n_tile].contiguous().unsqueeze(-1) + c_tile.zero_() + lse_tensor.zero_() + + mV = ct.from_dlpack(v_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_tile)) + mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile)) + mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor)) + + compiled(mQ, mK, mV, mC, stream, mLSE) + + o_unnorm[:, pv*pv_n_tile:(pv+1)*pv_n_tile] = c_tile[:,:,0].float() + + lse = lse_tensor[:,0,0] + row_sum = lse.exp() + o_norm = o_unnorm / row_sum.unsqueeze(-1) + o_bf16 = o_norm.to(torch.bfloat16) + + # Per-head reference + o_ref = torch.zeros(n_h, T, hd, dtype=torch.bfloat16, device='cuda') + for h in range(n_h): + o_ref[h, 0], _ = reference_fmha(q_heads[h], k[:,:,0], v, scale) + o_ref_flat = o_ref.reshape(n_h * T, hd) + + cos = torch.nn.functional.cosine_similarity( + o_bf16.flatten().float().unsqueeze(0), o_ref_flat.flatten().float().unsqueeze(0) + ).item() + print(f" cos = {cos:.6f}") + assert cos >= 0.99, f"cosine too low: {cos}" + print(" ✅ PASS") + + +def test(): + print("=== D2: Head-Packed FMHA ===") + test_d2_regression() + test_d2_headpacked_128() + print("\n=== ALL TESTS PASSED ===") if __name__ == '__main__':