diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index a5c9a533..4a643645 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -491,8 +491,8 @@ class FmhaKernel: # Step 6: TMA store SMEM → GMEM # The normalized O is now in sC (written by the correction epilog). # TMA store from sC to the output tensor in GMEM. - # Use cpasync.tma_partition to set up the SMEM/GMEM partition. - gC = cute.local_tile(mC, epi_tile, (Int32(0), Int32(0))) + # Use the pre-partitioned tCgC (GMEM partition) and sC (SMEM buffer). + gC = cute.local_tile(mC, cute.slice_(self.pv_mma_tiler,(None,0,None)),(None,None,None)) tCgC_epi = cute.flat_divide(tCgC, epi_tile) bSG_sC, bSG_gC = cpasync.tma_partition( tma_c, 0, cute.make_layout(1),