fix: use tma_partition for TMA store in correction_epilog

This commit is contained in:
2026-05-23 01:20:09 +00:00
parent e70230825a
commit 6b2dda02ec

View File

@@ -460,25 +460,22 @@ class FmhaV3StageCMulti:
cute.arch.fence_proxy("async.shared", space="cta")
# TMA store: SMEM → GMEM
epi_s_tile = cute.select(self.c_smem_s, mode=[0, 1])
tma_c_epi, mC_epi = cpasync.make_tiled_tma_atom(
cpasync.CopyBulkTensorTileS2GOp(), c, epi_s_tile, self.epi_tile
)
tCgC_epi = cute.local_tile(mC_epi, cute.slice_(self.pv_mma_tiler, (None, None, 0)), (None, None, None))
tCsC_epi = cute.local_tile(sC, cute.slice_(self.epi_tile, (None, None)), (None, None))
# Sync before TMA store — all softmax warps must finish SMEM writes
# TMA store: SMEM → GMEM (reuse existing tma_c from kernel setup)
# Sync all softmax warps before TMA store
softmax_all_bar = pipeline.NamedBarrier(
barrier_id=5, num_threads=32 * len(self.epilogue_warp_id)
)
softmax_all_bar.arrive_and_wait()
# Warp 0 does the TMA store
if sfw_idx < 32:
cute.copy(tma_c_epi, tCsC_epi, tCgC_epi)
cute.arch.cp_async_bulk_commit_group()
cute.arch.cp_async_bulk_wait_group(0, read=True)
# Partition SMEM and GMEM for TMA store
tCsC, tCgC_tma = cpasync.tma_partition(
tma_c, 0, cute.make_layout((1,)),
cute.group_modes(sC, 0, 2),
cute.group_modes(tCgC, 0, 3),
)
cute.copy(tma_c, tCsC, tCgC_tma)
cute.arch.cp_async_bulk_commit_group()
cute.arch.cp_async_bulk_wait_group(0, read=True)
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)