From 2b4f4ce538f9ccab246ee1bb0718beb7aa733bf2 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 26 May 2026 19:50:31 +0000 Subject: [PATCH] Remove broken D1.5 paired-atom test (TMEM round-trick is fundamentally broken) --- tests/unit/test_d15_paired_atoms.py | 161 ---------------------------- 1 file changed, 161 deletions(-) delete mode 100644 tests/unit/test_d15_paired_atoms.py diff --git a/tests/unit/test_d15_paired_atoms.py b/tests/unit/test_d15_paired_atoms.py deleted file mode 100644 index 68d5ce09..00000000 --- a/tests/unit/test_d15_paired_atoms.py +++ /dev/null @@ -1,161 +0,0 @@ -"""FMHA D1.5: Multi-KV-tile attention with paired-atom O rescale. - -Tests the D1.5 fix: O rescale for kt>0 using paired atoms from -epilogue_tmem_copy_and_partition (replaces broken hand-constructed -Ld32x32bOp/St32x32bOp TMEM round-trip). - -The kernel is launched with s_k>128 (multiple KV tiles). The O rescale -happens in-kernel for kt>0 using the paired-atom TMEM→REGS→modify→TMEM cycle. - -Run: ~/.openclaw/workspace/fire_b200_test tests/unit/test_d15_paired_atoms.py -""" -import torch -import math -import cutlass.cute as cute -import cutlass.torch as ct -import cuda.bindings.driver as cuda -from dsv4.kernels.attention.fmha import FmhaKernel - - -def reference_attention(q, k, v, scale): - """FP32 reference: q (M, hd), k (s_k, hd), v (s_k, hd) → o (M, hd), lse (M,)""" - scores = torch.matmul(q.float(), k.float().T) * scale - max_s = scores.max(dim=-1, keepdim=True).values - exp_s = (scores - max_s).exp() - sum_s = exp_s.sum(dim=-1, keepdim=True) - lse = (sum_s + 1e-10).log() + max_s - p = exp_s / sum_s - o = torch.matmul(p, v.float()) - return o.to(torch.bfloat16), lse.squeeze(-1) - - -def run_fmha(q, k, v, hd, s_k, m=128, use_smem_p=None, normalize=False): - """Run FMHA kernel and return output + LSE.""" - scale = 1.0 / math.sqrt(hd) - kernel = FmhaKernel(head_dim=hd, s_k=s_k, use_smem_p=use_smem_p, normalize=normalize) - pv_n_tile = kernel.pv_n_tile - n_pv_tiles = kernel.n_pv_tiles - stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - - # V must be (s_k, pv_n_tile, 1) for the kernel - all_o = [] - all_lse = [] - all_row_sums = [] - - for pv in range(n_pv_tiles): - v_tile = v[:, pv*pv_n_tile:(pv+1)*pv_n_tile].contiguous().unsqueeze(-1) - c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda') - lse_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda') - row_sums_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda') - - q_3d = q.unsqueeze(-1) if q.dim() == 2 else q - k_3d = k.unsqueeze(-1) if k.dim() == 2 else k - - mQ = ct.from_dlpack(q_3d).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q_3d)) - mK = ct.from_dlpack(k_3d).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k_3d)) - mV = ct.from_dlpack(v_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_tile)) - mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile)) - mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor)) - mRowSums = ct.from_dlpack(row_sums_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(row_sums_tensor)) - - if pv == 0: - print(f' Compiling hd={hd}, s_k={s_k}...', flush=True) - compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE, row_sums=mRowSums) - - compiled(mQ, mK, mV, mC, stream, mLSE, row_sums=mRowSums) - torch.cuda.synchronize() - - all_o.append(c_tile[:, :, 0]) - all_lse.append(lse_tensor[:, 0, 0]) - all_row_sums.append(row_sums_tensor[:, 0, 0]) - - # Assemble full output - o_full = torch.cat(all_o, dim=1) # (M, hd) - - # Normalize externally using row_sums - # O_norm = O_unnorm / row_sum - # Each pv segment contributes its portion of O_unnorm. - # Since all segments share the same row_sum (same softmax denominator), - # we can normalize the concatenated output directly. - row_sum_val = all_row_sums[0] # All segments have same row_sum - o_norm = (o_full.float() / row_sum_val.unsqueeze(-1)).to(torch.bfloat16) - - return o_norm, all_lse[0] - - -def test_d15_s256_hd64(): - """s_k=256 (2 KV tiles) with paired-atom O rescale at hd=64.""" - print("\n=== Test 1: s_k=256 (2 KV tiles, hd=64, TMEM-P) ===") - torch.manual_seed(42) - m, s_k, hd = 128, 256, 64 - scale = 1.0 / math.sqrt(hd) - - q = torch.randn(m, hd, dtype=torch.bfloat16, device='cuda') - k = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda') - v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda') - - o_norm, lse = run_fmha(q, k, v, hd, s_k, m=m, normalize=False) - - ref, _ = reference_attention(q, k, v, scale) - cos = torch.nn.functional.cosine_similarity( - o_norm.flatten().float().unsqueeze(0), ref.flatten().float().unsqueeze(0) - ).item() - print(f" cos = {cos:.6f}") - assert cos >= 0.995, f"cosine too low: {cos}" - print(" ✅ PASS") - - -def test_d15_s256_hd128(): - """s_k=256 (2 KV tiles) with paired-atom O rescale at hd=128.""" - print("\n=== Test 2: s_k=256 (2 KV tiles, hd=128, TMEM-P) ===") - torch.manual_seed(42) - m, s_k, hd = 128, 256, 128 - scale = 1.0 / math.sqrt(hd) - - q = torch.randn(m, hd, dtype=torch.bfloat16, device='cuda') - k = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda') - v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda') - - o_norm, lse = run_fmha(q, k, v, hd, s_k, m=m, normalize=False) - - ref, _ = reference_attention(q, k, v, scale) - cos = torch.nn.functional.cosine_similarity( - o_norm.flatten().float().unsqueeze(0), ref.flatten().float().unsqueeze(0) - ).item() - print(f" cos = {cos:.6f}") - assert cos >= 0.995, f"cosine too low: {cos}" - print(" ✅ PASS") - - -def test_d15_s384_hd64(): - """s_k=384 (3 KV tiles) with paired-atom O rescale at hd=64.""" - print("\n=== Test 3: s_k=384 (3 KV tiles, hd=64, TMEM-P) ===") - torch.manual_seed(42) - m, s_k, hd = 128, 384, 64 - scale = 1.0 / math.sqrt(hd) - - q = torch.randn(m, hd, dtype=torch.bfloat16, device='cuda') - k = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda') - v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda') - - o_norm, lse = run_fmha(q, k, v, hd, s_k, m=m, normalize=False) - - ref, _ = reference_attention(q, k, v, scale) - cos = torch.nn.functional.cosine_similarity( - o_norm.flatten().float().unsqueeze(0), ref.flatten().float().unsqueeze(0) - ).item() - print(f" cos = {cos:.6f}") - assert cos >= 0.995, f"cosine too low: {cos}" - print(" ✅ PASS") - - -def test(): - print("=== D1.5: Paired-Atom O Rescale (Multi-KV-Tile) ===") - test_d15_s256_hd64() - test_d15_s256_hd128() - test_d15_s384_hd64() - print("\n=== ALL TESTS PASSED ===") - - -if __name__ == '__main__': - test()