From 7189165a6735653f9bb6e2b64be04ee57c7370be Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 21 May 2026 19:26:15 +0000 Subject: [PATCH] WIP: TMEM vector bridge not working (same cosine 0.513) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit row_sum is PROVEN correct (29.25 vs 29.22 for row 0, ratio 1.001). The ONLY bug is QK→PV row mapping in C9 normalization. Tried: composition(tStS,(128,1)) for write, composition(tOtO,(128,1)) for read. Same result — the composition preserves the fragments internal thread-to-address mapping, so the same thread writes and reads the same TMEM address regardless of which fragment layout is used for the composition. Need: absolute row-coordinate indexed TMEM vector. Each QK thread writes inv_row_sum to vec[QK_row_id], each PV thread reads from vec[PV_row_id]. The row_id comes from the identity tensor coordinate. Alternative: implement FMHA correction_epilog pattern with dedicated correction warp group that reads row metadata from the vector. --- tests/unit/test_fmha_v3_softmax.py | 75 +++++++++++++++++++++++++++--- 1 file changed, 68 insertions(+), 7 deletions(-) diff --git a/tests/unit/test_fmha_v3_softmax.py b/tests/unit/test_fmha_v3_softmax.py index dcc635da..e49d69b0 100644 --- a/tests/unit/test_fmha_v3_softmax.py +++ b/tests/unit/test_fmha_v3_softmax.py @@ -52,7 +52,8 @@ class FmhaV3Softmax: p_end = self.tmem_p0_offset + p_cols_fp32 # 32 + 64 = 96 s_cols = self.qk_mma_tiler[1] # 128 o_after = max(s_cols, p_end) # 128 - self.tmem_o0_offset = ((o_after + 31) // 32) * 32 # align to 32 = 128 + self.tmem_o0_offset = ((o_after + 31) // 32) * 32 + self.tmem_vec_offset = 0 # Reuse S region for per-row inv_row_sum vector # align to 32 = 128 self.tmem_vec_offset = 0 # Reuse S region (free after softmax loop) o_cols = find_tmem_tensor_col_offset(tOtO) # footprint of O total = self.tmem_o0_offset + o_cols @@ -72,7 +73,7 @@ class FmhaV3Softmax: self.a_major = LayoutEnum.from_tensor(q).mma_major_mode() self.b_major = LayoutEnum.from_tensor(k).mma_major_mode() # # s_k hardcoded # BROKEN in @cute.jit - # FMHA-style V: reconstruct as (HEAD_DIM, 128, 1) MN-major + # FMHA-style V: reconstruct as (HEAD_DIM, s_k, 1) MN-major v_fmha = cute.make_tensor( v.iterator, cute.make_layout( @@ -373,9 +374,37 @@ class FmhaV3Softmax: row_sum = row_sum + tile_sum - # --- C9: SKIPPED for debug (no normalization) --- + # --- C9: Final normalization via O TMEM rescale --- pv_done_bar.arrive_and_wait() - # O is unnormalized in TMEM. Use standard epilogue_tma_store with identity. + # Store final row_sum to TMEM vector (per-row, using QK partition) + tTMEM_STORE_VECrS_final = cute.make_rmem_tensor(tTMEM_STORE_VECcS.shape, self.qk_acc_dtype) + tTMEM_STORE_VECrS_final[0] = row_sum + cute.copy(tiled_tmem_store_vec, tTMEM_STORE_VECrS_final, tTMEM_STORE_VECtS) + cute.arch.fence_view_async_tmem_store() + + # Read vector back: per-row row_sum using QK partition coordinates + tTMEM_LOAD_VECrS = cute.make_rmem_tensor(tTMEM_LOAD_VECcS.shape, self.qk_acc_dtype) + cute.copy(tiled_tmem_load_vec, tTMEM_LOAD_VECtS, tTMEM_LOAD_VECrS) + cute.arch.fence_view_async_tmem_load() + inv_row_sum = cutlass.Float32(1.0) / tTMEM_LOAD_VECrS[0] + + # Normalize O in TMEM + tTMrO_final = cute.make_rmem_tensor((tTMEM_LOADcO.shape, o_col_tiles), self.qk_acc_dtype) + for i in range(o_col_tiles): + tTMrO_i_ = tTMrO_final[None, i] + tTMrO_i_layout = cute.composition(tTMrO_i_.layout, cute.make_layout(tTMrO_final.shape[0])) + tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) + tTMEM_LOADtO_i = cute.make_tensor( + tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout) + tTMEM_STOREtO_i = cute.make_tensor( + tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout) + cute.copy(o_tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i) + for j in cutlass.range(cute.size(tTMrO_i), vectorize=True): + tTMrO_i[j] = tTMrO_i[j] * inv_row_sum + cute.copy(o_tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i) + cute.arch.fence_view_async_tmem_store() + + # Now O in TMEM is normalized. Use standard epilogue_tma_store with identity. tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage) c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) @@ -401,8 +430,7 @@ def test(): c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device="cuda") qf = q[:,:,0].float(); kf = k[:,:,0].float() attn = qf @ kf.T / math.sqrt(hd) - P = torch.exp(attn - attn.max(dim=-1, keepdim=True)[0]) - ref = P @ v.float() # unnormalized P@V + ref = torch.softmax(attn, dim=-1) @ v.float() mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) @@ -418,7 +446,40 @@ def test(): out = c[:,:,0].float() cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item() max_err = (out - ref).abs().max().item() - print(f"FMHA softmax (no C9 norm) n={n}: cosine {cos:.6f} max_err {max_err:.6f} {'PASS' if cos >= 0.999 else 'FAIL'}", flush=True) + print(f"FMHA softmax n={n}: cosine {cos:.6f} max_err {max_err:.6f} {'PASS' if cos >= 0.999 else 'FAIL'}", flush=True) + +if __name__ == "__main__": + test() + + +def test(): + import math + torch.manual_seed(42) + for n in [128, 256, 384]: + m, hd = 128, HEAD_DIM + q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device="cuda") + k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device="cuda") + v = torch.randn(n, hd, dtype=torch.bfloat16, device="cuda") + v_kernel = v.unsqueeze(-1) + c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device="cuda") + qf = q[:,:,0].float(); kf = k[:,:,0].float() + attn = qf @ kf.T / math.sqrt(hd) + ref = torch.softmax(attn, dim=-1) @ v.float() + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) + mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) + mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + kernel = FmhaV3Softmax() + print(f"n={n}: Compiling...", flush=True) + compiled = cute.compile(kernel, mQ, mK, mV, mC, stream) + print(f"n={n}: Running...", flush=True) + compiled(mQ, mK, mV, mC, stream) + torch.cuda.synchronize() + out = c[:,:,0].float() + cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item() + max_err = (out - ref).abs().max().item() + print(f"FMHA softmax n={n}: cosine {cos:.6f} max_err {max_err:.6f} {'PASS' if cos >= 0.999 else 'FAIL'}", flush=True) if __name__ == "__main__": test()