diag: NO-OP round-trip before normalize on 2D pattern

This commit is contained in:
2026-05-23 02:32:40 +00:00
parent 6cf1f17904
commit 49bf6e8294

View File

@@ -407,6 +407,28 @@ class FmhaV3StageCMulti:
# Wait for MMA's PV[N-1] to commit before reading O.
final_o_bar.arrive_and_wait()
# DIAG: NO-OP TMEM round-trip before normalize
tTMrO_noop = cute.make_rmem_tensor(
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
)
for i in range(n_corr_tiles):
tTMrO_i_ = tTMrO_noop[None, i]
tTMrO_i_layout = cute.composition(
tTMrO_i_.layout, cute.make_layout(tTMrO_noop.shape[0])
)
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
tTMEM_LOADtO_i = cute.make_tensor(
tTMEM_LOADtO.iterator + i * corr_tile_size,
tTMEM_LOADtO.layout,
)
tTMEM_STOREtO_i = cute.make_tensor(
tTMEM_STOREtO.iterator + i * corr_tile_size,
tTMEM_STOREtO.layout,
)
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
cute.arch.fence_view_async_tmem_store()
# === Final O normalization: O *= 1/row_sum ===
inv_row_sum = Float32(1.0) / row_sum