Try cute.copy(tma_c, sC_flat, gO) directly
This commit is contained in:
@@ -624,17 +624,10 @@ class FmhaKernel:
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
|
||||
# Step 2: TMA store sC_flat -> GMEM
|
||||
gO_qdl = cute.flat_divide(
|
||||
mC, cute.select(self.pv_mma_tiler, mode=[0, 1])
|
||||
)
|
||||
gO = gO_qdl[None, None, None, Int32(0), Int32(0)]
|
||||
tOsC, tOgO = cpasync.tma_partition(
|
||||
tma_c, 0, cute.make_layout(1),
|
||||
cute.group_modes(sC_flat, 0, 2),
|
||||
cute.group_modes(gO, 0, 2),
|
||||
)
|
||||
# Use cute.copy with tma_c directly (AutoCopy-style)
|
||||
gO = cute.local_tile(mC, cute.select(self.pv_mma_tiler, (None, None, 0)), (Int32(0), Int32(0), Int32(0)))
|
||||
if warp_idx == self.epilogue_warp_id[0]:
|
||||
cute.copy(tma_c, tOsC, tOgO)
|
||||
cute.copy(tma_c, sC_flat, gO)
|
||||
cute.arch.cp_async_bulk_commit_group()
|
||||
cute.arch.cp_async_bulk_wait_group(0, read=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user