diff --git a/tests/fmha_v3_real_softmax.py b/tests/fmha_v3_real_softmax.py index 492f423e..e5467563 100644 --- a/tests/fmha_v3_real_softmax.py +++ b/tests/fmha_v3_real_softmax.py @@ -297,17 +297,32 @@ class FmhaV3RealSoftmax: softmax_done_bar.arrive() # Final O normalization: O = O / row_sum + # Uses the CUTLASS reference's sub-tile approach: + # Load O from TMEM in sub-tiles of corr_tile_size columns, + # multiply by 1/row_sum, write back. if row_sum != Float32(0.0): inv_row_sum = Float32(1.0) / row_sum - tTMEM_LOAD_OrO = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.acc_dtype) - cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO, tTMEM_LOAD_OrO) - cute.arch.fence_view_async_tmem_load() - # The register tensor from the O partition is 2D: (frg, corr_tile) - for fi in range(cute.size(tTMEM_LOAD_OrO, mode=[0])): - for fj in range(cute.size(tTMEM_LOAD_OrO, mode=[1])): - tTMEM_LOAD_OrO[fi, fj] = tTMEM_LOAD_OrO[fi, fj] * inv_row_sum - cute.copy(tiled_tmem_store_o, tTMEM_LOAD_OrO, tTMEM_STORE_OtO) - cute.arch.fence_view_async_tmem_store() + # Register tensor: (frg, n_corr_tiles) where n_corr = 128/corr_tile_size + n_corr = 128 // corr_tile_size + tTMrO = cute.make_rmem_tensor( + (tTMEM_LOAD_OcO.shape, n_corr), self.acc_dtype + ) + for ci in range(n_corr): + tTMrO_ci_ = tTMrO[None, ci] + tTMrO_ci_layout = cute.composition( + tTMrO_ci_.layout, cute.make_layout(tTMrO.shape[0]) + ) + tTMrO_ci = cute.make_tensor(tTMrO_ci_.iterator, tTMrO_ci_layout) + tTMEM_LOAD_OtO_ci = cute.make_tensor( + tTMEM_LOAD_OtO.iterator + ci * corr_tile_size, tTMEM_LOAD_OtO.layout + ) + tTMEM_STORE_OtO_ci = cute.make_tensor( + tTMEM_STORE_OtO.iterator + ci * corr_tile_size, tTMEM_STORE_OtO.layout + ) + cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_ci, tTMrO_ci) + for j in cutlass.range(cute.size(tTMrO_ci), vectorize=True): + tTMrO_ci[j] = tTMrO_ci[j] * inv_row_sum + cute.copy(tiled_tmem_store_o, tTMrO_ci, tTMEM_STORE_OtO_ci) # Epilogue: TMEM -> SMEM -> GMEM via TMA store tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)