diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index 45767c22..7d5b8594 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -460,25 +460,22 @@ class FmhaV3StageCMulti: cute.arch.fence_proxy("async.shared", space="cta") - # TMA store: SMEM → GMEM - epi_s_tile = cute.select(self.c_smem_s, mode=[0, 1]) - tma_c_epi, mC_epi = cpasync.make_tiled_tma_atom( - cpasync.CopyBulkTensorTileS2GOp(), c, epi_s_tile, self.epi_tile - ) - tCgC_epi = cute.local_tile(mC_epi, cute.slice_(self.pv_mma_tiler, (None, None, 0)), (None, None, None)) - tCsC_epi = cute.local_tile(sC, cute.slice_(self.epi_tile, (None, None)), (None, None)) - - # Sync before TMA store — all softmax warps must finish SMEM writes + # TMA store: SMEM → GMEM (reuse existing tma_c from kernel setup) + # Sync all softmax warps before TMA store softmax_all_bar = pipeline.NamedBarrier( barrier_id=5, num_threads=32 * len(self.epilogue_warp_id) ) softmax_all_bar.arrive_and_wait() - # Warp 0 does the TMA store - if sfw_idx < 32: - cute.copy(tma_c_epi, tCsC_epi, tCgC_epi) - cute.arch.cp_async_bulk_commit_group() - cute.arch.cp_async_bulk_wait_group(0, read=True) + # Partition SMEM and GMEM for TMA store + tCsC, tCgC_tma = cpasync.tma_partition( + tma_c, 0, cute.make_layout((1,)), + cute.group_modes(sC, 0, 2), + cute.group_modes(tCgC, 0, 3), + ) + cute.copy(tma_c, tCsC, tCgC_tma) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) tmem.relinquish_alloc_permit() tmem.free(tmem_ptr)