diff --git a/dsv4/kernels/attention/fmha_smem_acc.py b/dsv4/kernels/attention/fmha_smem_acc.py index c3d2889b..8316878a 100644 --- a/dsv4/kernels/attention/fmha_smem_acc.py +++ b/dsv4/kernels/attention/fmha_smem_acc.py @@ -627,12 +627,14 @@ class FmhaKernel: cute.arch.fence_proxy("async.shared", space="cta") # Step 2: TMA store sC_flat -> GMEM - gO = cute.local_tile(mCSimple, cute.slice_(self.pv_mma_tiler, (None, None, 0)), (None, None, None)) - # Group modes to match: sC_flat is 2D, gO needs to be grouped to 2D + # Use tCgC (already partitioned) for the GMEM side of TMA + # Transform tCgC layout (same as epilogue_tma_store) + tCgC_xfm = transform_partitioned_tensor_layout(tCgC) + tCgC_epi = cute.flat_divide(tCgC_xfm, epi_tile) tOsC, tOgO = cpasync.tma_partition( tma_c, 0, cute.make_layout(1), cute.group_modes(sC_flat, 0, 2), - cute.group_modes(gO, 0, len(cute.shape(gO))), + cute.group_modes(tCgC_epi, 0, 2), ) if warp_idx == self.epilogue_warp_id[0]: cute.copy(tma_c, tOsC, tOgO)