diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 463aba5d..a6d8108d 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -487,10 +487,15 @@ class FmhaKernel: epilog_sync_bar.arrive_and_wait() # TMA store: SMEM → GMEM + # Use the TMA partition from the kernel setup c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp) c_pipe.producer_acquire() - cute.copy(tma_c, sC, tCgC) + # Use epi_tile-matched GMEM partition + gC_epi = cute.local_tile(mC, epi_tile, (Int32(0), Int32(0))) + # TMA store from sC (2D view) to gC (2D view) + sC_epi = cute.select(sC, mode=[0, 1]) # 2D view for TMA + cute.copy(tma_c, sC_epi, gC_epi) c_pipe.producer_commit() cute.arch.gpu_bar_sync() c_pipe.producer_tail()