diff --git a/dsv4/kernels/attention/fmha_smem_acc.py b/dsv4/kernels/attention/fmha_smem_acc.py index 7bed17df..5f42e99a 100644 --- a/dsv4/kernels/attention/fmha_smem_acc.py +++ b/dsv4/kernels/attention/fmha_smem_acc.py @@ -592,7 +592,7 @@ class FmhaKernel: # ============================================================ if const_expr(not self.use_smem_accumulator): - # Path 1: epilogue_tma_store (reads O from TMEM) + # Path 1: epilogue_tma_store (reads O from TMEM, proven for n_kv=1) tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp) @@ -606,39 +606,24 @@ class FmhaKernel: ) c_pipe.producer_tail() else: - # Path 2: sO_acc -> sC -> TMA store to GMEM - # Cast sO_acc (FP32) -> sC (BF16) using sC's layout indexing. - # sC layout from make_smem_layout_epi: ((M,N), ?, num_c_stage, ...). - # Write via sC's native coordinate system. - for row in cutlass.range(0, 128, unroll=1): - for col in cutlass.range(0, self.pv_n_tile, unroll=1): - sC[(row, col), Int32(0), Int32(0)] = sO_acc[row, col].to(self.o_dtype) - - # TMA store sC -> GMEM using cpasync.tma_partition - cute.arch.fence_proxy("async.shared", space="cta") - epilog_sync_barrier = pipeline.NamedBarrier( - barrier_id=self.epilog_sync_bar_id, - num_threads=32 * len(self.epilogue_warp_id), - ) - epilog_sync_barrier.arrive_and_wait() - tCgC_xfm = transform_partitioned_tensor_layout(tCgC) - tCgC_epi = cute.flat_divide(tCgC_xfm, epi_tile) - bSG_sC, bSG_gC = cpasync.tma_partition( - tma_c, 0, cute.make_layout(1), - cute.group_modes(sC, 0, 2), - cute.group_modes(tCgC_epi, 0, 2), - ) - # Slice off MMA tile coordinates (same as epilogue_tma_store) - bSG_gC = bSG_gC[(None, None, None, Int32(0), Int32(0), Int32(0))] - c_pipe = pipeline.PipelineTmaStore.create( - num_stages=self.num_c_stage, - producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) - ) - c_pipe.producer_acquire() - if warp_idx == self.epilogue_warp_id[0]: - cute.copy(tma_c, bSG_sC[(None, Int32(0))], bSG_gC[(None, None, Int32(0))]) - c_pipe.producer_commit() - c_pipe.producer_tail() + # Path 2: write sO_acc (FP32) -> GMEM directly from registers + # Each thread handles one row (sfw_idx) of the output. + # Read sO_acc row, normalize, cast to BF16, write to GMEM. + # + # For GMEM writes, use gC (the local_tile of the output tensor). + # gC is created from mC with TMA-compatible layout. + # We write to it using cute.copy with a universal copy atom. + # + # Actually, the simplest approach: write each element directly + # to the GMEM output tensor using scalar stores. + # gC[sfw_idx, col, 0] = BF16(sO_acc[sfw_idx, col] / row_sum) + for col in cutlass.range(0, self.pv_n_tile, unroll=1): + row = sfw_idx + if row < Int32(128): + val = sO_acc[row, col] + if const_expr(not self.normalize): + val = val / row_sum + gC[Int32(row), Int32(col), Int32(0)] = val.to(self.o_dtype) # Compute LSE: lse = ln(row_sum) + row_max * ln(2) # Only when emitting un-normalized output (D5a path).