Use tCgC_epi (transformed) for GMEM side of TMA partition

This commit is contained in:
2026-05-27 05:34:40 +00:00
parent b02e103ac0
commit a0e9f7534b

View File

@@ -627,12 +627,14 @@ class FmhaKernel:
cute.arch.fence_proxy("async.shared", space="cta")
# Step 2: TMA store sC_flat -> GMEM
gO = cute.local_tile(mCSimple, cute.slice_(self.pv_mma_tiler, (None, None, 0)), (None, None, None))
# Group modes to match: sC_flat is 2D, gO needs to be grouped to 2D
# Use tCgC (already partitioned) for the GMEM side of TMA
# Transform tCgC layout (same as epilogue_tma_store)
tCgC_xfm = transform_partitioned_tensor_layout(tCgC)
tCgC_epi = cute.flat_divide(tCgC_xfm, epi_tile)
tOsC, tOgO = cpasync.tma_partition(
tma_c, 0, cute.make_layout(1),
cute.group_modes(sC_flat, 0, 2),
cute.group_modes(gO, 0, len(cute.shape(gO))),
cute.group_modes(tCgC_epi, 0, 2),
)
if warp_idx == self.epilogue_warp_id[0]:
cute.copy(tma_c, tOsC, tOgO)