From 8048aa4be64c965dcbf0a41ba73fe99e570fe8bb Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 00:46:52 +0000 Subject: [PATCH] D1.5: Simplify TMA store - use 2D sC_epi and gC_epi views --- dsv4/kernels/attention/fmha.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 4a643645..dabd0aed 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -490,25 +490,19 @@ class FmhaKernel: # Step 6: TMA store SMEM → GMEM # The normalized O is now in sC (written by the correction epilog). - # TMA store from sC to the output tensor in GMEM. - # Use the pre-partitioned tCgC (GMEM partition) and sC (SMEM buffer). - gC = cute.local_tile(mC, cute.slice_(self.pv_mma_tiler,(None,0,None)),(None,None,None)) - tCgC_epi = cute.flat_divide(tCgC, epi_tile) + # Use the same TMA store pattern as the CUTLASS FMHA reference. + # Partition sC and gC for the bulk TMA copy. + sC_epi = cute.select(sC, mode=[0, 1]) # 2D view for TMA + gC_epi = cute.local_tile(mC, epi_tile, (Int32(0), Int32(0))) # 2D output tile bSG_sC, bSG_gC = cpasync.tma_partition( tma_c, 0, cute.make_layout(1), - cute.group_modes(sC, 0, 2), - cute.group_modes(tCgC_epi, 0, 2), + sC_epi, gC_epi, ) # One TMA store for the full output tile if warp_idx == self.epilogue_warp_id[0]: - c_pipe = pipeline.PipelineTmaStore.create( - num_stages=self.num_c_stage, - producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), - ) - c_pipe.producer_acquire() - cute.copy(tma_c, bSG_sC[(None, 0)], bSG_gC[(None, 0)]) - c_pipe.producer_commit() - c_pipe.producer_tail() + cute.copy(tma_c, bSG_sC, bSG_gC) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) # D5a: Write LSE (log-softmax) when normalize=False # lse = ln(row_sum) + row_max * ln(2)