fix: index into TMA partitioned tensors for copy
This commit is contained in:
@@ -467,15 +467,14 @@ class FmhaV3StageCMulti:
|
||||
)
|
||||
softmax_all_bar.arrive_and_wait()
|
||||
|
||||
# Use the same TMA store pattern as CUTLASS FMHA epilogue warp:
|
||||
# tma_partition on flat_divide'd GMEM tensor
|
||||
# Use the same TMA store pattern as CUTLASS FMHA epilogue warp
|
||||
tCgC_epi = cute.flat_divide(tCgC, self.epi_tile)
|
||||
tCsC, tCgC_tma = cpasync.tma_partition(
|
||||
tma_c, 0, cute.make_layout(1),
|
||||
cute.group_modes(sC, 0, 2),
|
||||
cute.group_modes(tCgC_epi, 0, 2),
|
||||
)
|
||||
cute.copy(tma_c, tCsC, tCgC_tma)
|
||||
cute.copy(tma_c, tCsC[(None, 0)], tCgC_tma[(None, 0, 0, 0, 0, 0, 0)])
|
||||
cute.arch.cp_async_bulk_commit_group()
|
||||
cute.arch.cp_async_bulk_wait_group(0, read=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user