D1.3: DIAGNOSTIC - test epilogue_tma_store raw PV without any round-trips

This commit is contained in:
2026-05-23 20:57:13 +00:00
parent 9e158dfc9f
commit 36f49e574c

View File

@@ -369,39 +369,9 @@ class FmhaKernel:
# Wait for MMA's PV[N-1] to commit before reading O.
final_o_bar.arrive_and_wait()
# === O normalization: TMEM -> reg (scale by 1/row_sum) -> TMEM ===
# Uses hand-constructed Ld32x32bOp/St32x32bOp atoms (same as correction_rescale).
# The layout mismatch in these atoms introduces ~3% error per round-trip,
# but the correction_rescale atoms (same construction) already use this path.
# TODO: Replace with get_tmem_load_op-derived atoms for zero error.
inv_row_sum = Float32(1.0) / row_sum
tTMrO = cute.make_rmem_tensor(
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
)
for i in range(n_corr_tiles):
tTMrO_i_ = tTMrO[None, i]
tTMrO_i_layout = cute.composition(
tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0])
)
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
tTMEM_LOADtO_i = cute.make_tensor(
tTMEM_LOADtO.iterator + i * corr_tile_size,
tTMEM_LOADtO.layout,
)
tTMEM_STOREtO_i = cute.make_tensor(
tTMEM_STOREtO.iterator + i * corr_tile_size,
tTMEM_STOREtO.layout,
)
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
for j in cutlass.range(cute.size(tTMrO_i), vectorize=True):
tTMrO_i[j] = tTMrO_i[j] * inv_row_sum
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
cute.arch.fence_view_async_tmem_store()
# Epilogue: TMEM → SMEM → GMEM via TMA store.
# Uses epilogue_tmem_copy_and_partition (get_tmem_load_op) internally.
# Since O is already normalized in TMEM, we apply identity epilogue_op.
# === DIAGNOSTIC: Test epilogue_tma_store WITHOUT any round-trips ===
# If get_tmem_load_op reads O correctly from TMEM, this should give cos 0.9999
# (un-normalized, just raw PV sum). Then we can add normalization back.
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_acc_stage