diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index a6d8108d..fac28721 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -487,15 +487,14 @@ class FmhaKernel: epilog_sync_bar.arrive_and_wait() # TMA store: SMEM → GMEM - # Use the TMA partition from the kernel setup + # Reuse the existing TMA partition (tCgC) which was set up at kernel start. + # sC was written by the correction epilog. TMA reads from sC → GMEM via tCgC. c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp) c_pipe.producer_acquire() - # Use epi_tile-matched GMEM partition - gC_epi = cute.local_tile(mC, epi_tile, (Int32(0), Int32(0))) - # TMA store from sC (2D view) to gC (2D view) - sC_epi = cute.select(sC, mode=[0, 1]) # 2D view for TMA - cute.copy(tma_c, sC_epi, gC_epi) + # TMA store from sC to GMEM using the pre-partitioned gC + gC = cute.local_tile(mC, cute.slice_(self.pv_mma_tiler,(None,0,None)),(None,None,None)) + cute.copy(tma_c, cute.select(sC, mode=[0, 1]), cute.select(gC, mode=[0, 1])) c_pipe.producer_commit() cute.arch.gpu_bar_sync() c_pipe.producer_tail()