fix: add NO-OP TMEM round-trip to re-map O from MMA to epilog layout

This commit is contained in:
2026-05-23 02:50:53 +00:00
parent 45cf89a556
commit 3aba5cc6da

View File

@@ -330,6 +330,32 @@ class FmhaV3StageCMulti:
# Wait for MMA's PV[N-1] to commit before reading O.
final_o_bar.arrive_and_wait()
# === NO-OP TMEM round-trip: re-map O from MMA layout to epilog layout ===
# The MMA writes O in the C-fragment TMEM layout, but epilogue_tma_store
# reads using get_tmem_load_op which expects a different layout. A NO-OP
# load-then-store through the hand-constructed atoms re-maps the data.
# TODO: eliminate this by using get_tmem_load_op for the normalize.
tTMrO_noop = cute.make_rmem_tensor(
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
)
for i in range(n_corr_tiles):
tTMrO_i_ = tTMrO_noop[None, i]
tTMrO_i_layout = cute.composition(
tTMrO_i_.layout, cute.make_layout(tTMrO_noop.shape[0])
)
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
tTMEM_LOADtO_i = cute.make_tensor(
tTMEM_LOADtO.iterator + i * corr_tile_size,
tTMEM_LOADtO.layout,
)
tTMEM_STOREtO_i = cute.make_tensor(
tTMEM_STOREtO.iterator + i * corr_tile_size,
tTMEM_STOREtO.layout,
)
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
cute.arch.fence_view_async_tmem_store()
# === Final O normalization: O *= 1/row_sum ===
# TMEM round-trip using hand-constructed atoms.
# Known issue: hand-constructed Ld32x32bOp/St32x32bOp atoms introduce