From c9271ffbf4e2d0f1ec253299519bcc291114e728 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 01:23:04 +0000 Subject: [PATCH] fix: index into TMA partitioned tensors for copy --- tests/unit/test_fmha_v3_stage_c.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index 9ee32dff..ed446a31 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -467,15 +467,14 @@ class FmhaV3StageCMulti: ) softmax_all_bar.arrive_and_wait() - # Use the same TMA store pattern as CUTLASS FMHA epilogue warp: - # tma_partition on flat_divide'd GMEM tensor + # Use the same TMA store pattern as CUTLASS FMHA epilogue warp tCgC_epi = cute.flat_divide(tCgC, self.epi_tile) tCsC, tCgC_tma = cpasync.tma_partition( tma_c, 0, cute.make_layout(1), cute.group_modes(sC, 0, 2), cute.group_modes(tCgC_epi, 0, 2), ) - cute.copy(tma_c, tCsC, tCgC_tma) + cute.copy(tma_c, tCsC[(None, 0)], tCgC_tma[(None, 0, 0, 0, 0, 0, 0)]) cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=True)