Try cute.copy(tma_c, sC_flat, gO) directly

This commit is contained in:
2026-05-27 05:29:51 +00:00
parent 2af767a90c
commit b39d7f1a14

View File

@@ -624,17 +624,10 @@ class FmhaKernel:
cute.arch.fence_proxy("async.shared", space="cta")
# Step 2: TMA store sC_flat -> GMEM
gO_qdl = cute.flat_divide(
mC, cute.select(self.pv_mma_tiler, mode=[0, 1])
)
gO = gO_qdl[None, None, None, Int32(0), Int32(0)]
tOsC, tOgO = cpasync.tma_partition(
tma_c, 0, cute.make_layout(1),
cute.group_modes(sC_flat, 0, 2),
cute.group_modes(gO, 0, 2),
)
# Use cute.copy with tma_c directly (AutoCopy-style)
gO = cute.local_tile(mC, cute.select(self.pv_mma_tiler, (None, None, 0)), (Int32(0), Int32(0), Int32(0)))
if warp_idx == self.epilogue_warp_id[0]:
cute.copy(tma_c, tOsC, tOgO)
cute.copy(tma_c, sC_flat, gO)
cute.arch.cp_async_bulk_commit_group()
cute.arch.cp_async_bulk_wait_group(0, read=True)