diff --git a/dsv4/kernels/attention/fmha_smem_acc.py b/dsv4/kernels/attention/fmha_smem_acc.py index 0233fd4c..f02eaa6a 100644 --- a/dsv4/kernels/attention/fmha_smem_acc.py +++ b/dsv4/kernels/attention/fmha_smem_acc.py @@ -547,22 +547,31 @@ class FmhaKernel: # Step 2: TMA store sC -> GMEM # Use cpasync.tma_partition (same as epilogue_tma_store) cute.arch.fence_proxy("async.shared", space="cta") - c_pipe = pipeline.PipelineTmaStore.create( - num_stages=self.num_c_stage, - producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) + # Sync barrier for SMEM->GMEM ordering + epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=self.epilog_sync_bar_id, + num_threads=32 * len(self.epilogue_warp_id), ) - c_pipe.producer_acquire() - # Transform tCgC layout (same as epilogue_tma_store) - tCgC = transform_partitioned_tensor_layout(tCgC) - tCgC_epi = cute.flat_divide(tCgC, epi_tile) - # Create TMA partition from sC and gC + epilog_sync_barrier.arrive_and_wait() + # Transform tCgC layout and partition for TMA + tCgC_xfm = transform_partitioned_tensor_layout(tCgC) + tCgC_epi = cute.flat_divide(tCgC_xfm, epi_tile) bSG_sC, bSG_gC = cpasync.tma_partition( tma_c, 0, cute.make_layout(1), cute.group_modes(sC, 0, 2), cute.group_modes(tCgC_epi, 0, 2), ) - cute.copy(tma_c, bSG_sC[None, ...], bSG_gC[None, ...]) - c_pipe.producer_commit() + # TMA store: only the first epilogue warp does the copy + c_pipe = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) + ) + c_pipe.producer_acquire() + if warp_idx == self.epilogue_warp_id[0]: + cute.copy(tma_c, bSG_sC[(None, Int32(0))], bSG_gC[(None, Int32(0))]) + c_pipe.producer_commit() + c_pipe.producer_acquire() + epilog_sync_barrier.arrive_and_wait() c_pipe.producer_tail() tmem.relinquish_alloc_permit()