diff --git a/dsv4/kernels/attention/fmha_smem_acc.py b/dsv4/kernels/attention/fmha_smem_acc.py index 258c82db..6defea1c 100644 --- a/dsv4/kernels/attention/fmha_smem_acc.py +++ b/dsv4/kernels/attention/fmha_smem_acc.py @@ -522,44 +522,35 @@ class FmhaKernel: if row < Int32(128): sO_acc[row, col] = sO_acc[row, col] * inv_row_sum - # Copy sO_acc (FP32) -> sC (BF16) using SMEM copy - # sC has swizzled layout from compute_epilogue_tile_shape, - # but we can write to it using the epi_tile coordinate mapping. + # ============================================================ + # Cast sO_acc (FP32) -> sC (BF16) and TMA store to GMEM + # ============================================================ + # sC has a swizzled layout. We need to write using sC's native + # coordinate system. The epi_tile defines the logical tile shape. # - # Alternative: use TMA store directly from a properly laid out SMEM buffer. - # The simplest correct approach: use epilogue_tma_store but read from - # a SMEM buffer instead of TMEM. - # - # For the MVP, we use the existing sC layout and write via - # the epi_tile partition that TMA expects. + # Strategy: use epi_s (the TMA-compatible view of sC) to write + # sO_acc data into sC, then TMA copy sC -> gC. + # ============================================================ - # Use epilogue_tma_store to write sO_acc -> GMEM - # But epilogue_tma_store reads from TMEM, not SMEM. - # We need a different TMA store path. - # - # Simplest: use cpasync.bulk_copy (SMEM->GMEM) with sC as source. - # First: copy sO_acc -> sC (FP32->BF16 cast) - # Then: TMA bulk copy sC -> GMEM - # - # Write to sC row by row using the epi_tile coordinate mapping. - # The epi_tile shape is derived from cta_tile_shape_mnk. - # For hd=64 with pv_n_tile=64: epi_tile covers (128, 64). + # epi_s is the 2-mode view of sC that tma_c was created from + epi_s = cute.select(c_smem_s, mode=[0, 1]) + sC_view = cute.make_tensor(sC.iterator, epi_s) # TMA-compatible layout - # For each row assigned to this thread, cast FP32->BF16 - # and write to sC using flat index mapping. - # sC is 2-stage: sC[128, pv_n_tile, num_c_stage] in BF16 - c_stage0 = cute.slice_(sC, (None, None, 0)) # First stage of sC - for col in cutlass.range(0, self.pv_n_tile, unroll=1): - row = sfw_idx - if row < Int32(128): - c_stage0[row, col] = sO_acc[row, col].to(self.o_dtype) + # Write sO_acc -> sC using sC_view's coordinate system + # sC_view is indexed by epi_tile coordinates + # For simple row-major epi_tile: (row, col) works + for row in cutlass.range(0, 128, unroll=1): + for col in cutlass.range(0, self.pv_n_tile, unroll=1): + sC_view[row, col] = sO_acc[row, col].to(self.o_dtype) # TMA store sC -> GMEM cute.arch.fence_proxy("async.shared", space="cta") - c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) - c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp) + c_pipe = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) + ) c_pipe.producer_acquire() - cute.copy(tma_c, c_stage0, tCgC[(None, None, Int32(0))]) + cute.copy(tma_c, sC_view, tCgC[(None, None, Int32(0))]) c_pipe.producer_commit() c_pipe.producer_tail()