diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index c507ad72..c96d8b7d 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -286,7 +286,32 @@ class FmhaV3StageCMulti: row_sum = Float32(0.0) scale_log2 = Float32(self.scale_softmax_log2) - # Per-tile softmax loop. + # === O rescale setup (paired atoms for TMEM O read-modify-write) === + corr_tile_size = 16 + cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) + tOcO = pv_thr.partition_C(cO) + tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size))) + tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size))) + tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout) + tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout) + tmem_load_o_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.acc_dtype, + ) + tmem_store_o_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.acc_dtype, + ) + tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i) + tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i) + thr_tmem_load_o = tiled_tmem_load_o.get_slice(sfw_idx) + thr_tmem_store_o = tiled_tmem_store_o.get_slice(sfw_idx) + tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i) + tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i) + tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i) + n_corr_tiles = HEAD_DIM // corr_tile_size + + # Per-tile softmax loop with online O rescale. # Online softmax row_max/row_sum tracking is maintained, but the # in-place TMEM O rescale (which would multiply existing O by # exp2(old_max - new_max) before PV[kt]) is DISABLED — this is the @@ -344,69 +369,39 @@ class FmhaV3StageCMulti: cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() + # === Per-tile O rescale: O *= acc_scale for kt > 0 === + if kt > 0: + for i in range(n_corr_tiles): + tTMEM_LOADtO_i = cute.make_tensor( + tTMEM_LOADtO.iterator + i * corr_tile_size, + tTMEM_LOADtO.layout, + ) + tTMEM_STOREtO_i = cute.make_tensor( + tTMEM_STOREtO.iterator + i * corr_tile_size, + tTMEM_STOREtO.layout, + ) + tTMrO = cute.make_rmem_tensor(tTMEM_LOADcO.shape, self.acc_dtype) + cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO) + cute.arch.fence_view_async_tmem_load() + for k in cutlass.range(cute.size(tTMrO), vectorize=True): + tTMrO[k] = tTMrO[k] * acc_scale + cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STOREtO_i) + cute.arch.fence_view_async_tmem_store() + si_handle.release() softmax_done_bar.arrive() - # === Reference-style scaled epilogue (no TMEM round-trip) === - # - # Pattern (mirrors CUTLASS Blackwell FMHA reference's - # correction_epilog): for each column sub-tile, - # 1. TMEM -> registers via PAIRED tmem_load atom - # 2. scale in registers (1/row_sum) - # 3. FP32 -> BF16 conversion in registers - # 4. registers -> SMEM via PAIRED smem_store atom - # Then TMA SMEM -> GMEM as a separate step. - # - # Critical: the load and store atoms MUST be a matched pair. - # Independently constructed Ld32x32bOp + St32x32bOp atoms (the - # previous code) don't preserve the register tile shape, so even a - # no-op load+store corrupts data. Using utils.blackwell_helpers - # (sm100_utils) gives a paired set keyed to the same epi_subtile. - # Wait for MMA's PV[N-1] to commit before reading O. final_o_bar.arrive_and_wait() - # === O normalization via TMEM load → scale → TMEM store === - # Matches CUTLASS reference's correction_rescale pattern exactly. + # === Final O normalization: O *= 1/row_sum === + inv_row_sum = Float32(1.0) / row_sum - corr_tile_size = 16 - - cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) - tOcO = pv_thr.partition_C(cO) - - tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size))) - tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size))) - - tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout) - tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout) - - tmem_load_atom = cute.make_copy_atom( - tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), - self.acc_dtype, - ) - tmem_store_atom = cute.make_copy_atom( - tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), - self.acc_dtype, - ) - - tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_atom, tOtO_i) - tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_atom, tOtO_i) - - thr_tmem_load_o = tiled_tmem_load_o.get_slice(sfw_idx) - thr_tmem_store_o = tiled_tmem_store_o.get_slice(sfw_idx) - - tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i) - tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i) - tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i) - - # 2D register tensor: (frg_shape, n_corr_tiles) tTMrO = cute.make_rmem_tensor( (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype ) - inv_row_sum = Float32(1.0) / row_sum - - for i in range(HEAD_DIM // corr_tile_size): + for i in range(n_corr_tiles): tTMrO_i_ = tTMrO[None, i] tTMrO_i_layout = cute.composition( tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0])