From de2028b106b7e32546cfa07fd52fe3e9390bf2ef Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 27 May 2026 05:35:56 +0000 Subject: [PATCH] Split sC_flat into staged layout to match TMA atom decomposition --- dsv4/kernels/attention/fmha_smem_acc.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dsv4/kernels/attention/fmha_smem_acc.py b/dsv4/kernels/attention/fmha_smem_acc.py index 8316878a..1e1d326b 100644 --- a/dsv4/kernels/attention/fmha_smem_acc.py +++ b/dsv4/kernels/attention/fmha_smem_acc.py @@ -628,13 +628,14 @@ class FmhaKernel: # Step 2: TMA store sC_flat -> GMEM # Use tCgC (already partitioned) for the GMEM side of TMA - # Transform tCgC layout (same as epilogue_tma_store) tCgC_xfm = transform_partitioned_tensor_layout(tCgC) tCgC_epi = cute.flat_divide(tCgC_xfm, epi_tile) + # sC_flat (128, pv_n_tile) -> split to match TMA stage: (128, pv_n_tile//2, 2) + sC_flat_staged = cute.logical_divide(sC_flat, cute.make_layout((128, self.pv_n_tile // 2, 2), stride=(self.pv_n_tile, 2, 1))) tOsC, tOgO = cpasync.tma_partition( tma_c, 0, cute.make_layout(1), - cute.group_modes(sC_flat, 0, 2), - cute.group_modes(tCgC_epi, 0, 2), + sC_flat_staged, + tCgC_epi, ) if warp_idx == self.epilogue_warp_id[0]: cute.copy(tma_c, tOsC, tOgO)