diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index 7d5b8594..40c446a4 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -460,7 +460,7 @@ class FmhaV3StageCMulti: cute.arch.fence_proxy("async.shared", space="cta") - # TMA store: SMEM → GMEM (reuse existing tma_c from kernel setup) + # TMA store: SMEM → GMEM # Sync all softmax warps before TMA store softmax_all_bar = pipeline.NamedBarrier( barrier_id=5, num_threads=32 * len(self.epilogue_warp_id) @@ -468,10 +468,11 @@ class FmhaV3StageCMulti: softmax_all_bar.arrive_and_wait() # Partition SMEM and GMEM for TMA store + epi_s = cute.select(self.c_smem_s, mode=[0, 1]) tCsC, tCgC_tma = cpasync.tma_partition( tma_c, 0, cute.make_layout((1,)), cute.group_modes(sC, 0, 2), - cute.group_modes(tCgC, 0, 3), + cute.group_modes(gC, 0, 3), ) cute.copy(tma_c, tCsC, tCgC_tma) cute.arch.cp_async_bulk_commit_group()