D1.5: Fix TMA store - use flat_divide on tCgC instead of local_tile on mC
This commit is contained in:
@@ -492,18 +492,17 @@ class FmhaKernel:
|
||||
# The normalized O is now in sC (written by the correction epilog).
|
||||
# The tma_c was created with CopyBulkTensorTileS2GOp for c (3D) and epi_s (2D SMEM layout).
|
||||
# We need to partition sC and the GMEM output for the TMA copy.
|
||||
# Get 2D views: sC has 4D layout ((128,16),1,(4,2),1), group to 2D.
|
||||
# gC: use local_tile with 3D tile and 3D coordinate.
|
||||
epi_tile_3d = (epi_tile[0], epi_tile[1], 1)
|
||||
gC_epi = cute.local_tile(mC, epi_tile_3d, (Int32(0), Int32(0), Int32(0)))
|
||||
# Use flat_divide on the already-partitioned tCgC (same pattern
|
||||
# as CUTLASS epilogue_tma_store), then tma_partition.
|
||||
tCgC_epi = cute.flat_divide(tCgC, epi_tile)
|
||||
bSG_sC, bSG_gC = cpasync.tma_partition(
|
||||
tma_c, 0, cute.make_layout(1),
|
||||
cute.group_modes(sC, 0, 2),
|
||||
cute.group_modes(gC_epi, 0, 2),
|
||||
cute.group_modes(tCgC_epi, 0, 2),
|
||||
)
|
||||
# One TMA store for the full output tile
|
||||
if warp_idx == self.epilogue_warp_id[0]:
|
||||
cute.copy(tma_c, bSG_sC, bSG_gC)
|
||||
cute.copy(tma_c, bSG_sC[(None, 0)], bSG_gC[(None, None, None, Int32(0), Int32(0), Int32(0))])
|
||||
cute.arch.cp_async_bulk_commit_group()
|
||||
cute.arch.cp_async_bulk_wait_group(0, read=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user