fix: use flat_divide+group_modes(0,2) for TMA store, matching CUTLASS

This commit is contained in:
2026-05-23 01:22:22 +00:00
parent 5efa9c9297
commit e01ff282b7

View File

@@ -467,12 +467,13 @@ class FmhaV3StageCMulti:
)
softmax_all_bar.arrive_and_wait()
# Partition SMEM and GMEM for TMA store
epi_s = cute.select(self.c_smem_s, mode=[0, 1])
# Use the same TMA store pattern as CUTLASS FMHA epilogue warp:
# tma_partition on flat_divide'd GMEM tensor
tCgC_epi = cute.flat_divide(tCgC, self.epi_tile)
tCsC, tCgC_tma = cpasync.tma_partition(
tma_c, 0, cute.make_layout((1,)),
tma_c, 0, cute.make_layout(1),
cute.group_modes(sC, 0, 2),
cute.group_modes(gC, 0, 3),
cute.group_modes(tCgC_epi, 0, 2),
)
cute.copy(tma_c, tCsC, tCgC_tma)
cute.arch.cp_async_bulk_commit_group()