From e614d0894c97eebfe40572da877eddf7989169e2 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 27 May 2026 04:57:40 +0000 Subject: [PATCH] Clean up SMEM acc epilogue: flat indexing sO_acc->sC, TMA store from sC_s0 --- dsv4/kernels/attention/fmha_smem_acc.py | 36 ++++++++++++------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/dsv4/kernels/attention/fmha_smem_acc.py b/dsv4/kernels/attention/fmha_smem_acc.py index 6defea1c..d1f51e25 100644 --- a/dsv4/kernels/attention/fmha_smem_acc.py +++ b/dsv4/kernels/attention/fmha_smem_acc.py @@ -523,25 +523,25 @@ class FmhaKernel: sO_acc[row, col] = sO_acc[row, col] * inv_row_sum # ============================================================ - # Cast sO_acc (FP32) -> sC (BF16) and TMA store to GMEM + # ============================================================ - # sC has a swizzled layout. We need to write using sC's native - # coordinate system. The epi_tile defines the logical tile shape. - # - # Strategy: use epi_s (the TMA-compatible view of sC) to write - # sO_acc data into sC, then TMA copy sC -> gC. + # EPILOGUE: write sO_acc to GMEM + # ============================================================ + # Strategy: write sO_acc (FP32) -> sC (BF16) -> TMA store to GMEM. + # The sC write uses flat indexing over the stage-0 slice. + # The TMA store uses the epi_s layout that tma_c was created from. # ============================================================ - # epi_s is the 2-mode view of sC that tma_c was created from - epi_s = cute.select(c_smem_s, mode=[0, 1]) - sC_view = cute.make_tensor(sC.iterator, epi_s) # TMA-compatible layout - - # Write sO_acc -> sC using sC_view's coordinate system - # sC_view is indexed by epi_tile coordinates - # For simple row-major epi_tile: (row, col) works - for row in cutlass.range(0, 128, unroll=1): - for col in cutlass.range(0, self.pv_n_tile, unroll=1): - sC_view[row, col] = sO_acc[row, col].to(self.o_dtype) + # Cast sO_acc (FP32) -> sC (BF16), stage 0 + # sC layout from make_smem_layout_epi: complex swizzled layout. + # We write via flat indexing on the stage-0 slice. + # The layout will handle the swizzle automatically. + sC_s0 = sC[(None, None, Int32(0))] + for i in cutlass.range(0, cute.size(sC_s0), unroll=1): + row = i // self.pv_n_tile + col = i % self.pv_n_tile + if row < Int32(128) and col < self.pv_n_tile: + sC_s0[i] = sO_acc[row, col].to(self.o_dtype) # TMA store sC -> GMEM cute.arch.fence_proxy("async.shared", space="cta") @@ -550,9 +550,9 @@ class FmhaKernel: producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) ) c_pipe.producer_acquire() - cute.copy(tma_c, sC_view, tCgC[(None, None, Int32(0))]) + cute.copy(tma_c, sC_s0, tCgC[(None, None, Int32(0))]) c_pipe.producer_commit() c_pipe.producer_tail() tmem.relinquish_alloc_permit() - tmem.free(tmem_ptr) \ No newline at end of file + tmem.free(tmem_ptr)