From 86a4719afe6c0141efbbcb627832d3eb73b3ed6f Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 19:17:03 +0000 Subject: [PATCH] O normalize with full layout (no sub-tiling), Repetition(64) --- tests/fmha_v3_real_softmax.py | 42 ++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/tests/fmha_v3_real_softmax.py b/tests/fmha_v3_real_softmax.py index aa3096ac..8349c09b 100644 --- a/tests/fmha_v3_real_softmax.py +++ b/tests/fmha_v3_real_softmax.py @@ -226,23 +226,19 @@ class FmhaV3RealSoftmax: tScP = cute.make_tensor(tScS.iterator, tScP_layout) tTMEM_STOREcP = thr_store.partition_S(tScP) - # O normalize setup: sub-tile O for TMEM read-modify-write - cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) - tOcO = pv_thr.partition_C(cO) - corr_tile_size = 32 - tOtO_i_layout = cute.composition(tOtO.layout, cute.make_layout((128, corr_tile_size))) - tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size))) - tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout) - tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout) - tmem_load_o_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype) - tmem_store_o_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype) - tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i) - tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i) + # O normalize setup: use FULL O layout (no sub-tiling) + # Match the P store pattern for TMEM access + tmem_load_o_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(64)), self.acc_dtype) + tmem_store_o_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(64)), self.acc_dtype) + tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO0) + tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO0) thr_load_o = tiled_tmem_load_o.get_slice(sfw_idx) thr_store_o = tiled_tmem_store_o.get_slice(sfw_idx) - tTMEM_LOAD_OtO = thr_load_o.partition_S(tOtO_i) - tTMEM_LOAD_OcO = thr_load_o.partition_D(tOcO_i) - tTMEM_STORE_OtO = thr_store_o.partition_D(tOtO_i) + tTMEM_LOAD_OtO = thr_load_o.partition_S(tOtO0) + cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) + tOcO = pv_thr.partition_C(cO) + tTMEM_LOAD_OcO = thr_load_o.partition_D(tOcO) + tTMEM_STORE_OtO = thr_store_o.partition_D(tOtO0) row_max = -Float32.inf row_sum = Float32(0.0) @@ -299,7 +295,14 @@ class FmhaV3RealSoftmax: si_handle.release() softmax_done_bar.arrive() - # Final O normalization: DISABLED for NO-OP test + # Final O normalization: O = O / row_sum + if row_sum != Float32(0.0): + inv_row_sum = Float32(1.0) / row_sum + tTMrO = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.acc_dtype) + cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO, tTMrO) + for j in cutlass.range(cute.size(tTMrO), vectorize=True): + tTMrO[j] = tTMrO[j] * inv_row_sum + cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STORE_OtO) # Epilogue: TMEM -> SMEM -> GMEM via TMA store tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) @@ -322,14 +325,13 @@ def test(): v_kernel = v.unsqueeze(-1) c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda') - # Reference: unnormalized (just softmax numerators @ V, no 1/row_sum) + # Reference: proper softmax qf = q[:, :, 0].float() kf = k[:, :, 0].float() scale = 1.0 / math.sqrt(hd) attn = qf @ kf.T * scale - attn_max = attn.max(dim=-1, keepdim=True)[0] - attn_unnorm = torch.exp(attn - attn_max) - ref = attn_unnorm @ v.float() + attn = torch.softmax(attn, dim=-1) + ref = attn @ v.float() mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))