From e01ff282b7afa0a87ddc52385730018ea2dbc20d Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 01:22:22 +0000 Subject: [PATCH] fix: use flat_divide+group_modes(0,2) for TMA store, matching CUTLASS --- tests/unit/test_fmha_v3_stage_c.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index 40c446a4..9ee32dff 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -467,12 +467,13 @@ class FmhaV3StageCMulti: ) softmax_all_bar.arrive_and_wait() - # Partition SMEM and GMEM for TMA store - epi_s = cute.select(self.c_smem_s, mode=[0, 1]) + # Use the same TMA store pattern as CUTLASS FMHA epilogue warp: + # tma_partition on flat_divide'd GMEM tensor + tCgC_epi = cute.flat_divide(tCgC, self.epi_tile) tCsC, tCgC_tma = cpasync.tma_partition( - tma_c, 0, cute.make_layout((1,)), + tma_c, 0, cute.make_layout(1), cute.group_modes(sC, 0, 2), - cute.group_modes(gC, 0, 3), + cute.group_modes(tCgC_epi, 0, 2), ) cute.copy(tma_c, tCsC, tCgC_tma) cute.arch.cp_async_bulk_commit_group()