diff --git a/dsv4/kernels/attention/fmha_smem_acc.py b/dsv4/kernels/attention/fmha_smem_acc.py index 9e52d7b7..0537f382 100644 --- a/dsv4/kernels/attention/fmha_smem_acc.py +++ b/dsv4/kernels/attention/fmha_smem_acc.py @@ -624,17 +624,10 @@ class FmhaKernel: cute.arch.fence_proxy("async.shared", space="cta") # Step 2: TMA store sC_flat -> GMEM - gO_qdl = cute.flat_divide( - mC, cute.select(self.pv_mma_tiler, mode=[0, 1]) - ) - gO = gO_qdl[None, None, None, Int32(0), Int32(0)] - tOsC, tOgO = cpasync.tma_partition( - tma_c, 0, cute.make_layout(1), - cute.group_modes(sC_flat, 0, 2), - cute.group_modes(gO, 0, 2), - ) + # Use cute.copy with tma_c directly (AutoCopy-style) + gO = cute.local_tile(mC, cute.select(self.pv_mma_tiler, (None, None, 0)), (Int32(0), Int32(0), Int32(0))) if warp_idx == self.epilogue_warp_id[0]: - cute.copy(tma_c, tOsC, tOgO) + cute.copy(tma_c, sC_flat, gO) cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=True)