diff --git a/tests/fmha_v3_real_softmax.py b/tests/fmha_v3_real_softmax.py index 8349c09b..3af9faf0 100644 --- a/tests/fmha_v3_real_softmax.py +++ b/tests/fmha_v3_real_softmax.py @@ -226,19 +226,27 @@ class FmhaV3RealSoftmax: tScP = cute.make_tensor(tScS.iterator, tScP_layout) tTMEM_STOREcP = thr_store.partition_S(tScP) - # O normalize setup: use FULL O layout (no sub-tiling) - # Match the P store pattern for TMEM access - tmem_load_o_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(64)), self.acc_dtype) - tmem_store_o_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(64)), self.acc_dtype) - tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO0) - tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO0) - thr_load_o = tiled_tmem_load_o.get_slice(sfw_idx) - thr_store_o = tiled_tmem_store_o.get_slice(sfw_idx) - tTMEM_LOAD_OtO = thr_load_o.partition_S(tOtO0) + # O normalize setup: use the SAME base pointer as the epilogue + # The epilogue reads O from tmem_ptr + tmem_o0_offset. + # We must use the same base to access the correct TMEM columns. + tCtO_norm = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tOtO.layout) cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) tOcO = pv_thr.partition_C(cO) - tTMEM_LOAD_OcO = thr_load_o.partition_D(tOcO) - tTMEM_STORE_OtO = thr_store_o.partition_D(tOtO0) + # Sub-tile the O layout for the normalize copy + corr_tile_size = 16 + tOtO_i_layout = cute.composition(tCtO_norm.layout, cute.make_layout((128, corr_tile_size))) + tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size))) + tOtO_norm_i = cute.make_tensor(tCtO_norm.iterator, tOtO_i_layout) + tOcO_norm_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout) + tmem_load_o_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype) + tmem_store_o_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype) + tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_norm_i) + tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_norm_i) + thr_load_o = tiled_tmem_load_o.get_slice(sfw_idx) + thr_store_o = tiled_tmem_store_o.get_slice(sfw_idx) + tTMEM_LOAD_OtO = thr_load_o.partition_S(tOtO_norm_i) + tTMEM_LOAD_OcO = thr_load_o.partition_D(tOcO_norm_i) + tTMEM_STORE_OtO = thr_store_o.partition_D(tOtO_norm_i) row_max = -Float32.inf row_sum = Float32(0.0) @@ -298,11 +306,26 @@ class FmhaV3RealSoftmax: # Final O normalization: O = O / row_sum if row_sum != Float32(0.0): inv_row_sum = Float32(1.0) / row_sum - tTMrO = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.acc_dtype) - cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO, tTMrO) - for j in cutlass.range(cute.size(tTMrO), vectorize=True): - tTMrO[j] = tTMrO[j] * inv_row_sum - cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STORE_OtO) + n_corr = HEAD_DIM // corr_tile_size + tTMrO = cute.make_rmem_tensor( + (tTMEM_LOAD_OcO.shape, n_corr), self.acc_dtype + ) + for ci in range(n_corr): + tTMrO_ci_ = tTMrO[None, ci] + tTMrO_ci_layout = cute.composition( + tTMrO_ci_.layout, cute.make_layout(tTMrO.shape[0]) + ) + tTMrO_ci = cute.make_tensor(tTMrO_ci_.iterator, tTMrO_ci_layout) + tTMEM_LOAD_OtO_ci = cute.make_tensor( + tTMEM_LOAD_OtO.iterator + ci * corr_tile_size, tTMEM_LOAD_OtO.layout + ) + tTMEM_STORE_OtO_ci = cute.make_tensor( + tTMEM_STORE_OtO.iterator + ci * corr_tile_size, tTMEM_STORE_OtO.layout + ) + cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_ci, tTMrO_ci) + for j in cutlass.range(cute.size(tTMrO_ci), vectorize=True): + tTMrO_ci[j] = tTMrO_ci[j] * inv_row_sum + cute.copy(tiled_tmem_store_o, tTMrO_ci, tTMEM_STORE_OtO_ci) # Epilogue: TMEM -> SMEM -> GMEM via TMA store tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)