diff --git a/dsv4/kernels/attention/fmha_smem_acc.py b/dsv4/kernels/attention/fmha_smem_acc.py index d1f51e25..0233fd4c 100644 --- a/dsv4/kernels/attention/fmha_smem_acc.py +++ b/dsv4/kernels/attention/fmha_smem_acc.py @@ -525,32 +525,43 @@ class FmhaKernel: # ============================================================ # ============================================================ - # EPILOGUE: write sO_acc to GMEM + # EPILOGUE: write sO_acc to sC, TMA store sC -> GMEM # ============================================================ - # Strategy: write sO_acc (FP32) -> sC (BF16) -> TMA store to GMEM. - # The sC write uses flat indexing over the stage-0 slice. - # The TMA store uses the epi_s layout that tma_c was created from. + # Write sO_acc (FP32) -> sC (BF16) using sC's layout indexing. + # Then use cpasync.tma_partition + cute.copy for TMA store. # ============================================================ - # Cast sO_acc (FP32) -> sC (BF16), stage 0 - # sC layout from make_smem_layout_epi: complex swizzled layout. - # We write via flat indexing on the stage-0 slice. - # The layout will handle the swizzle automatically. - sC_s0 = sC[(None, None, Int32(0))] - for i in cutlass.range(0, cute.size(sC_s0), unroll=1): - row = i // self.pv_n_tile - col = i % self.pv_n_tile - if row < Int32(128) and col < self.pv_n_tile: - sC_s0[i] = sO_acc[row, col].to(self.o_dtype) + # Step 1: Write sO_acc -> sC (BF16) + # Index sC with hierarchical (row, col) coordinates; stage 0 is selected explicitly. + # sC has layout from make_smem_layout_epi: ((M, N), ?, num_c_stage, ...). + # We write to stage 0; ideally this would go through the epi_s (2-mode) view, + # but since we can't easily create a tensor with epi_s layout from sC's pointer + # (swizzle conflict), we write to sC via its native 4D layout. + # + # sC shape: ((128, pv_n_tile), 1, num_c_stage, ...) + # Index: sC[(row, col), 0, stage_idx, ...] (stage_idx is 0 in the loop below)
+ for row in cutlass.range(0, 128, unroll=1): + for col in cutlass.range(0, self.pv_n_tile, unroll=1): + sC[(row, col), Int32(0), Int32(0)] = sO_acc[row, col].to(self.o_dtype) - # TMA store sC -> GMEM + # Step 2: TMA store sC -> GMEM + # Use cpasync.tma_partition (same as epilogue_tma_store) cute.arch.fence_proxy("async.shared", space="cta") c_pipe = pipeline.PipelineTmaStore.create( num_stages=self.num_c_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) ) c_pipe.producer_acquire() - cute.copy(tma_c, sC_s0, tCgC[(None, None, Int32(0))]) + # Transform tCgC layout (same as epilogue_tma_store) + tCgC = transform_partitioned_tensor_layout(tCgC) + tCgC_epi = cute.flat_divide(tCgC, epi_tile) + # Create TMA partition from sC and gC + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_c, 0, cute.make_layout(1), + cute.group_modes(sC, 0, 2), + cute.group_modes(tCgC_epi, 0, 2), + ) + cute.copy(tma_c, bSG_sC[None, ...], bSG_gC[None, ...]) c_pipe.producer_commit() c_pipe.producer_tail()