diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index 9ee32dff..ed446a31 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -467,15 +467,14 @@ class FmhaV3StageCMulti: ) softmax_all_bar.arrive_and_wait() - # Use the same TMA store pattern as CUTLASS FMHA epilogue warp: - # tma_partition on flat_divide'd GMEM tensor + # Use the same TMA store pattern as CUTLASS FMHA epilogue warp tCgC_epi = cute.flat_divide(tCgC, self.epi_tile) tCsC, tCgC_tma = cpasync.tma_partition( tma_c, 0, cute.make_layout(1), cute.group_modes(sC, 0, 2), cute.group_modes(tCgC_epi, 0, 2), ) - cute.copy(tma_c, tCsC, tCgC_tma) + cute.copy(tma_c, tCsC[(None, 0)], tCgC_tma[(None, 0, 0, 0, 0, 0, 0)]) cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=True)