diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index 8270c49c..ea84ee35 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -330,6 +330,32 @@ class FmhaV3StageCMulti: # Wait for MMA's PV[N-1] to commit before reading O. final_o_bar.arrive_and_wait() + # === NO-OP TMEM round-trip: re-map O from MMA layout to epilog layout === + # The MMA writes O in the C-fragment TMEM layout, but epilogue_tma_store + # reads using get_tmem_load_op which expects a different layout. A NO-OP + # load-then-store through the hand-constructed atoms re-maps the data. + # TODO: eliminate this by using get_tmem_load_op for the normalize. + tTMrO_noop = cute.make_rmem_tensor( + (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype + ) + for i in range(n_corr_tiles): + tTMrO_i_ = tTMrO_noop[None, i] + tTMrO_i_layout = cute.composition( + tTMrO_i_.layout, cute.make_layout(tTMrO_noop.shape[0]) + ) + tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) + tTMEM_LOADtO_i = cute.make_tensor( + tTMEM_LOADtO.iterator + i * corr_tile_size, + tTMEM_LOADtO.layout, + ) + tTMEM_STOREtO_i = cute.make_tensor( + tTMEM_STOREtO.iterator + i * corr_tile_size, + tTMEM_STOREtO.layout, + ) + cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i) + cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i) + cute.arch.fence_view_async_tmem_store() + # === Final O normalization: O *= 1/row_sum === # TMEM round-trip using hand-constructed atoms. # Known issue: hand-constructed Ld32x32bOp/St32x32bOp atoms introduce