diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 2722e10a..4ce2b669 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -492,18 +492,17 @@ class FmhaKernel: # The normalized O is now in sC (written by the correction epilog). # The tma_c was created with CopyBulkTensorTileS2GOp for c (3D) and epi_s (2D SMEM layout). # We need to partition sC and the GMEM output for the TMA copy. - # Get 2D views: sC has 4D layout ((128,16),1,(4,2),1), group to 2D. - # gC: use local_tile with 3D tile and 3D coordinate. - epi_tile_3d = (epi_tile[0], epi_tile[1], 1) - gC_epi = cute.local_tile(mC, epi_tile_3d, (Int32(0), Int32(0), Int32(0))) + # Use flat_divide on the already-partitioned tCgC (same pattern + # as CUTLASS epilogue_tma_store), then tma_partition. + tCgC_epi = cute.flat_divide(tCgC, epi_tile) bSG_sC, bSG_gC = cpasync.tma_partition( tma_c, 0, cute.make_layout(1), cute.group_modes(sC, 0, 2), - cute.group_modes(gC_epi, 0, 2), + cute.group_modes(tCgC_epi, 0, 2), ) # One TMA store for the full output tile if warp_idx == self.epilogue_warp_id[0]: - cute.copy(tma_c, bSG_sC, bSG_gC) + cute.copy(tma_c, bSG_sC[(None, 0)], bSG_gC[(None, None, None, Int32(0), Int32(0), Int32(0))]) cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=True)