Split sC_flat into staged layout to match TMA atom decomposition

This commit is contained in:
2026-05-27 05:35:56 +00:00
parent a0e9f7534b
commit de2028b106

View File

@@ -628,13 +628,14 @@ class FmhaKernel:
# Step 2: TMA store sC_flat -> GMEM
# Use tCgC (already partitioned) for the GMEM side of TMA
# Transform tCgC layout (same as epilogue_tma_store)
tCgC_xfm = transform_partitioned_tensor_layout(tCgC)
tCgC_epi = cute.flat_divide(tCgC_xfm, epi_tile)
# sC_flat has shape (128, pv_n_tile); split it into (128, pv_n_tile // 2, 2) to match the TMA stage layout
sC_flat_staged = cute.logical_divide(sC_flat, cute.make_layout((128, self.pv_n_tile // 2, 2), stride=(self.pv_n_tile, 2, 1)))
tOsC, tOgO = cpasync.tma_partition(
tma_c, 0, cute.make_layout(1),
cute.group_modes(sC_flat, 0, 2),
cute.group_modes(tCgC_epi, 0, 2),
sC_flat_staged,
tCgC_epi,
)
if warp_idx == self.epilogue_warp_id[0]:
cute.copy(tma_c, tOsC, tOgO)