From 6fb0e6a417eb7d51d84c81eb81974fcb8a8401ab Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 27 May 2026 05:26:50 +0000 Subject: [PATCH] Use sC_flat (non-swizzled epi_s layout) for TMA store from SMEM accumulator --- dsv4/kernels/attention/fmha_smem_acc.py | 29 ++++++++++--------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/dsv4/kernels/attention/fmha_smem_acc.py b/dsv4/kernels/attention/fmha_smem_acc.py index 39b190e0..d6cf9004 100644 --- a/dsv4/kernels/attention/fmha_smem_acc.py +++ b/dsv4/kernels/attention/fmha_smem_acc.py @@ -224,8 +224,12 @@ class FmhaKernel: if const_expr(self.use_smem_accumulator): sO_acc_layout = cute.make_layout((128, self.pv_n_tile), stride=(self.pv_n_tile, 1)) sO_acc = smem.allocate_tensor(element_type=Float32, layout=sO_acc_layout, byte_alignment=128) + # sC_flat: BF16 SMEM buffer with epi_s layout (non-swizzled) for TMA store + # Used to cast sO_acc (FP32) -> BF16 and TMA store to GMEM + sC_flat = smem.allocate_tensor(element_type=self.o_dtype, layout=cute.select(self.c_smem_s, mode=[0, 1]).outer, byte_alignment=128) else: sO_acc = smem.allocate_tensor(element_type=Float32, layout=cute.make_layout((1, 1), stride=(1, 1)), byte_alignment=128) + sC_flat = smem.allocate_tensor(element_type=self.o_dtype, layout=cute.make_layout((1, 1), stride=(1, 1)), byte_alignment=128) gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None)) gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None)) @@ -606,40 +610,31 @@ class FmhaKernel: ) c_pipe.producer_tail() else: - # Path 2: write sO_acc (FP32) -> sC -> GMEM via TMA - # Follow the CUTLASS FMHA reference pattern: - # 1. Cast sO_acc (FP32) -> sC (BF16) using sC's layout indexing - # 2. Use flat_divide on mC to create gO, then tma_partition, then copy + # Path 2: write sO_acc (FP32) -> sC_flat (BF16) -> TMA store to GMEM + # sC_flat has epi_s layout (same as what tma_c was created from) - # Step 1: Cast sO_acc -> sC (BF16) + # Step 1: Cast sO_acc -> sC_flat (BF16) for row in cutlass.range(0, 128, unroll=1): for col in cutlass.range(0, self.pv_n_tile, unroll=1): val = sO_acc[row, col] if const_expr(self.normalize): inv_row_sum = Float32(1.0) / row_sum val = val * inv_row_sum - sC[(row, col), Int32(0), Int32(0)] = val.to(self.o_dtype) + sC_flat[row, col] = val.to(self.o_dtype) cute.arch.fence_proxy("async.shared", space="cta") - # Step 2: TMA store sC -> GMEM (CUTLASS FMHA reference pattern) - # Create gO from mC via flat_divide + slice (same as CUTLASS reference) + # Step 2: TMA store sC_flat -> GMEM gO_qdl = cute.flat_divide( mC, cute.select(self.pv_mma_tiler, mode=[0, 1]) ) gO = gO_qdl[None, None, None, Int32(0), Int32(0)] - tOsO, tOgO = cpasync.tma_partition( + tOsC, tOgO = cpasync.tma_partition( tma_c, 0, cute.make_layout(1), - cute.group_modes(sC, 0, 2), + cute.group_modes(sC_flat, 0, 2), cute.group_modes(gO, 0, 2), ) - # Wait for all epilogue warps to finish writing to sC - epilog_sync_barrier = pipeline.NamedBarrier( - barrier_id=self.epilog_sync_bar_id, - num_threads=32 * len(self.epilogue_warp_id), - ) - epilog_sync_barrier.arrive_and_wait() if warp_idx == self.epilogue_warp_id[0]: - cute.copy(tma_c, tOsO[None, Int32(0)], tOgO[None, Int32(0)]) + cute.copy(tma_c, tOsC[None, Int32(0)], tOgO[None, Int32(0)]) cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=True)