fix: index into TMA partitioned tensors for copy

This commit is contained in:
2026-05-23 01:23:04 +00:00
parent e01ff282b7
commit c9271ffbf4

View File

@@ -467,15 +467,14 @@ class FmhaV3StageCMulti:
)
softmax_all_bar.arrive_and_wait()
# Use the same TMA store pattern as CUTLASS FMHA epilogue warp:
# tma_partition on flat_divide'd GMEM tensor
# Use the same TMA store pattern as CUTLASS FMHA epilogue warp
tCgC_epi = cute.flat_divide(tCgC, self.epi_tile)
tCsC, tCgC_tma = cpasync.tma_partition(
tma_c, 0, cute.make_layout(1),
cute.group_modes(sC, 0, 2),
cute.group_modes(tCgC_epi, 0, 2),
)
cute.copy(tma_c, tCsC, tCgC_tma)
cute.copy(tma_c, tCsC[(None, 0)], tCgC_tma[(None, 0, 0, 0, 0, 0, 0)])
cute.arch.cp_async_bulk_commit_group()
cute.arch.cp_async_bulk_wait_group(0, read=True)