diff --git a/tests/fmha_v3_stage_c_example7.py b/tests/fmha_v3_stage_c_example7.py index 6e5ae19b..22a0d4b9 100644 --- a/tests/fmha_v3_stage_c_example7.py +++ b/tests/fmha_v3_stage_c_example7.py @@ -289,27 +289,27 @@ class FmhaV3StageCMulti: # Per-tile softmax loop. # Online softmax row_max/row_sum tracking is maintained, but the # in-place TMEM O rescale (which would multiply existing O by - # O rescale setup: same correction_rescale pattern as the final normalize - corr_tile_size_rs = 16 - cO_rs = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) - tOcO_rs = pv_thr.partition_C(cO_rs) - tOtO_rs_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size_rs))) - tOcO_rs_i_layout = cute.composition(tOcO_rs.layout, cute.make_layout((128, corr_tile_size_rs))) - tOtO_rs_i = cute.make_tensor(tOtO0.iterator, tOtO_rs_i_layout) - tOcO_rs_i = cute.make_tensor(tOcO_rs.iterator, tOcO_rs_i_layout) - tmem_load_o_rs_atom = cute.make_copy_atom( - tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size_rs)), self.acc_dtype) - tmem_store_o_rs_atom = cute.make_copy_atom( - tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size_rs)), self.acc_dtype) - tiled_tmem_load_o_rs = tcgen05.make_tmem_copy(tmem_load_o_rs_atom, tOtO_rs_i) - tiled_tmem_store_o_rs = tcgen05.make_tmem_copy(tmem_store_o_rs_atom, tOtO_rs_i) - thr_tmem_load_o_rs = tiled_tmem_load_o_rs.get_slice(sfw_idx) - thr_tmem_store_o_rs = tiled_tmem_store_o_rs.get_slice(sfw_idx) - tTMEM_LOAD_OtO_rs = thr_tmem_load_o_rs.partition_S(tOtO_rs_i) - tTMEM_LOAD_OcO_rs = thr_tmem_load_o_rs.partition_D(tOcO_rs_i) - tTMEM_STORE_OtO_rs = thr_tmem_store_o_rs.partition_D(tOtO_rs_i) - tTMrO_rs = cute.make_rmem_tensor( - (tTMEM_LOAD_OcO_rs.shape, 128 // corr_tile_size_rs), self.acc_dtype) + # O rescale + final normalize setup: single set of correction_rescale tensors + corr_tile_size = 16 + cO_corr = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) + tOcO_corr = pv_thr.partition_C(cO_corr) + tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size))) + tOcO_i_layout = cute.composition(tOcO_corr.layout, cute.make_layout((128, corr_tile_size))) + tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout) + tOcO_i = cute.make_tensor(tOcO_corr.iterator, tOcO_i_layout) + tmem_load_o_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype) + tmem_store_o_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype) + tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i) + tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i) + thr_tmem_load_o = tiled_tmem_load_o.get_slice(sfw_idx) + thr_tmem_store_o = tiled_tmem_store_o.get_slice(sfw_idx) + tTMEM_LOAD_OtO = thr_tmem_load_o.partition_S(tOtO_i) + tTMEM_LOAD_OcO = thr_tmem_load_o.partition_D(tOcO_i) + tTMEM_STORE_OtO = thr_tmem_store_o.partition_D(tOtO_i) + tTMrO = cute.make_rmem_tensor( + (tTMEM_LOAD_OcO.shape, 128 // corr_tile_size), self.acc_dtype) row_max = -Float32.inf row_sum = Float32(0.0) @@ -345,24 +345,23 @@ class FmhaV3StageCMulti: row_sum *= acc_scale # O rescale: multiply existing O by acc_scale = exp2(old_max - new_max) - # Uses the same correction_rescale pattern verified for final normalize. if kt > 0: - for ci in range(HEAD_DIM // corr_tile_size_rs): - tTMrO_rs_i_ = tTMrO_rs[None, ci] - tTMrO_rs_i_layout = cute.composition( - tTMrO_rs_i_.layout, cute.make_layout(tTMrO_rs.shape[0]) + for ci in range(HEAD_DIM // corr_tile_size): + tTMrO_i_ = tTMrO[None, ci] + tTMrO_i_layout = cute.composition( + tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) ) - tTMrO_rs_i = cute.make_tensor(tTMrO_rs_i_.iterator, tTMrO_rs_i_layout) - tTMEM_LOAD_OtO_rs_i = cute.make_tensor( - tTMEM_LOAD_OtO_rs.iterator + ci * corr_tile_size_rs, tTMEM_LOAD_OtO_rs.layout + tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) + tTMEM_LOAD_OtO_i = cute.make_tensor( + tTMEM_LOAD_OtO.iterator + ci * corr_tile_size, tTMEM_LOAD_OtO.layout ) - tTMEM_STORE_OtO_rs_i = cute.make_tensor( - tTMEM_STORE_OtO_rs.iterator + ci * corr_tile_size_rs, tTMEM_STORE_OtO_rs.layout + tTMEM_STORE_OtO_i = cute.make_tensor( + tTMEM_STORE_OtO.iterator + ci * corr_tile_size, tTMEM_STORE_OtO.layout ) - cute.copy(tiled_tmem_load_o_rs, tTMEM_LOAD_OtO_rs_i, tTMrO_rs_i) - for j in cutlass.range(cute.size(tTMrO_rs_i), vectorize=True): - tTMrO_rs_i[j] = tTMrO_rs_i[j] * acc_scale - cute.copy(tiled_tmem_store_o_rs, tTMrO_rs_i, tTMEM_STORE_OtO_rs_i) + cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_i, tTMrO_i) + for j in cutlass.range(cute.size(tTMrO_i), vectorize=True): + tTMrO_i[j] = tTMrO_i[j] * acc_scale + cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STORE_OtO_i) cute.arch.fence_view_async_tmem_store() # Pass 2: P = exp2((S - new_max) * log2), accumulate row_sum, @@ -405,44 +404,8 @@ class FmhaV3StageCMulti: # Wait for MMA's PV[N-1] to commit before reading O. final_o_bar.arrive_and_wait() - # === O normalization via TMEM load → scale → TMEM store === - # Matches CUTLASS reference's correction_rescale pattern exactly. - - corr_tile_size = 16 - - cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) - tOcO = pv_thr.partition_C(cO) - - tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size))) - tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size))) - - tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout) - tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout) - - tmem_load_atom = cute.make_copy_atom( - tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), - self.acc_dtype, - ) - tmem_store_atom = cute.make_copy_atom( - tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), - self.acc_dtype, - ) - - tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_atom, tOtO_i) - tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_atom, tOtO_i) - - thr_tmem_load_o = tiled_tmem_load_o.get_slice(sfw_idx) - thr_tmem_store_o = tiled_tmem_store_o.get_slice(sfw_idx) - - tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i) - tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i) - tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i) - - # 2D register tensor: (frg_shape, n_corr_tiles) - tTMrO = cute.make_rmem_tensor( - (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype - ) - + # === Final O normalization: 1/row_sum === + # Reuses the same corr_tile_size + tiled_tmem_load_o/store_o from above. inv_row_sum = Float32(1.0) / row_sum for i in range(HEAD_DIM // corr_tile_size): @@ -451,17 +414,17 @@ class FmhaV3StageCMulti: tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) ) tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) - tTMEM_LOADtO_i = cute.make_tensor( - tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout + tTMEM_LOAD_OtO_i = cute.make_tensor( + tTMEM_LOAD_OtO.iterator + i * corr_tile_size, tTMEM_LOAD_OtO.layout ) - tTMEM_STOREtO_i = cute.make_tensor( - tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout + tTMEM_STORE_OtO_i = cute.make_tensor( + tTMEM_STORE_OtO.iterator + i * corr_tile_size, tTMEM_STORE_OtO.layout ) - cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i) + cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_i, tTMrO_i) for j in cutlass.range(cute.size(tTMrO_i), vectorize=True): tTMrO_i[j] = tTMrO_i[j] * inv_row_sum - cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i) + cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STORE_OtO_i) cute.arch.fence_view_async_tmem_store()