fix: add NO-OP TMEM round-trip to re-map O from MMA to epilog layout
This commit is contained in:
@@ -330,6 +330,32 @@ class FmhaV3StageCMulti:
|
||||
# Wait for MMA's PV[N-1] to commit before reading O.
|
||||
final_o_bar.arrive_and_wait()
|
||||
|
||||
# === NO-OP TMEM round-trip: re-map O from MMA layout to epilog layout ===
|
||||
# The MMA writes O in the C-fragment TMEM layout, but epilogue_tma_store
|
||||
# reads using get_tmem_load_op which expects a different layout. A NO-OP
|
||||
# load-then-store through the hand-constructed atoms re-maps the data.
|
||||
# TODO: eliminate this by using get_tmem_load_op for the normalize.
|
||||
tTMrO_noop = cute.make_rmem_tensor(
|
||||
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
|
||||
)
|
||||
for i in range(n_corr_tiles):
|
||||
tTMrO_i_ = tTMrO_noop[None, i]
|
||||
tTMrO_i_layout = cute.composition(
|
||||
tTMrO_i_.layout, cute.make_layout(tTMrO_noop.shape[0])
|
||||
)
|
||||
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
|
||||
tTMEM_LOADtO_i = cute.make_tensor(
|
||||
tTMEM_LOADtO.iterator + i * corr_tile_size,
|
||||
tTMEM_LOADtO.layout,
|
||||
)
|
||||
tTMEM_STOREtO_i = cute.make_tensor(
|
||||
tTMEM_STOREtO.iterator + i * corr_tile_size,
|
||||
tTMEM_STOREtO.layout,
|
||||
)
|
||||
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
|
||||
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# === Final O normalization: O *= 1/row_sum ===
|
||||
# TMEM round-trip using hand-constructed atoms.
|
||||
# Known issue: hand-constructed Ld32x32bOp/St32x32bOp atoms introduce
|
||||
|
||||
Reference in New Issue
Block a user