diff --git a/dsv4/kernels/attention/fmha_smem_acc.py b/dsv4/kernels/attention/fmha_smem_acc.py index a0a4e30e..f64c6921 100644 --- a/dsv4/kernels/attention/fmha_smem_acc.py +++ b/dsv4/kernels/attention/fmha_smem_acc.py @@ -66,6 +66,8 @@ class FmhaKernel: self.num_c_stage = 1 if head_dim > 256 else 2 # Reduce SMEM at hd=512 self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim) self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e) + # SMEM accumulator: needed when n_kv_tiles > 1 (TMEM round-trip is broken) + self.use_smem_accumulator = self.n_kv_tiles > 1 def _setup(self, qk_mma, pv_mma): qk_ik = cute.size(qk_mma.shape_mnk, mode=[2]) @@ -217,6 +219,14 @@ class FmhaKernel: _p_swizzle = cute.make_layout(((1,1),1,(1,1),1)) sP = smem.allocate_tensor(element_type=self.q_dtype,layout=_p_layout,byte_alignment=128,swizzle=_p_swizzle) + # SMEM accumulator: FP32 [128, pv_n_tile] row-major + # Only needed for n_kv_tiles > 1 (TMEM round-trip is broken) + if const_expr(self.use_smem_accumulator): + sO_acc_layout = cute.make_layout((128, self.pv_n_tile), stride=(self.pv_n_tile, 1)) + sO_acc = smem.allocate_tensor(element_type=Float32, layout=sO_acc_layout, byte_alignment=128) + else: + sO_acc = smem.allocate_tensor(element_type=Float32, layout=cute.make_layout((1, 1), stride=(1, 1)), byte_alignment=128) + gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None)) gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None)) gV = cute.local_tile(mV,cute.slice_(self.pv_mma_tiler,(0,None,None)),(None,None,None)) @@ -342,7 +352,7 @@ class FmhaKernel: cute.arch.fence_view_async_tmem_store() sh.commit() softmax_done_bar.arrive_and_wait() - pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0) + pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0 and not self.use_smem_accumulator) if not self.use_smem_p: for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True): cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0) @@ -405,6 +415,21 @@ class FmhaKernel: # This is representable as a CuTe layout: (128, (16, 4, 2)) -> (64, (1, 16, 8192)) _sP_nostage = sP[(None, None, None, 0)] # remove stage dim + # O load atoms: one-way TMEM->REGS read of O_kt after PV (for SMEM accumulator) + # Only needed when use_smem_accumulator (n_kv_tiles > 1) + if const_expr(self.use_smem_accumulator): + o_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_o_load = tcgen05.make_tmem_copy(o_load_atom, tOtO0) + thr_o_load = tiled_o_load.get_slice(sfw_idx) + tTMEM_LOADtO = thr_o_load.partition_S(tOtO0) + cO = cute.make_identity_tensor((self.qk_mma_tiler[0], self.pv_mma_tiler[1])) + tOcO = pv_thr.partition_C(cO) + tTMEM_LOADcO = thr_o_load.partition_D(tOcO) + + # Zero-initialize sO_acc + for i in cutlass.range(0, cute.size(sO_acc), unroll=1): + sO_acc[i] = Float32(0.0) + row_max = -Float32.inf row_sum = Float32(0.0) scale_log2 = Float32(self.scale_softmax_log2) @@ -539,37 +564,79 @@ class FmhaKernel: si_handle.release() softmax_done_bar.arrive() + # --- SMEM accumulator: load O_kt from TMEM, accumulate in sO_acc --- + if const_expr(self.use_smem_accumulator): + pv_done_bar.arrive_and_wait() + rO = cute.make_rmem_tensor(tTMEM_LOADcO.shape, self.qk_acc_dtype) + cute.copy(tiled_o_load, tTMEM_LOADtO, rO) + cute.arch.fence_view_async_tmem_load() + + # sO_acc = acc_scale * sO_acc + O_kt + for j0 in range(32): + for j1 in range(4): + coord = tTMEM_LOADcO[(j0, 0), j1, 0, 0] + row = coord[0] + col = coord[1] + old_val = sO_acc[row, col] + new_val = acc_scale * old_val + rO[(j0, 0), j1, 0, 0] + sO_acc[row, col] = new_val + # Wait for MMA's PV[N-1] to commit before reading O. final_o_bar.arrive_and_wait() # ============================================================ - # EPILOGUE: TMA store O to GMEM + compute LSE + # EPILOGUE: store O to GMEM + compute LSE # ============================================================ - # The raw un-normalized O in TMEM is perfect (cos 0.999998). - # We use epilogue_tma_store which reads O from TMEM directly via - # the correct get_tmem_load_op layout — no round-trip needed. - # - # For multi-KV-tile: the paired-atom O rescale above (kt>0) ensures - # O is correctly rescaled before this epilogue reads it. - # - # External normalization (D5a path): kernel outputs un-normalized O + - # LSE + row_sum. Caller normalizes using O_norm = O_unnorm / row_sum. - # This is exact and composes with D5c sink bias merge. + # For n_kv_tiles=1: O is in TMEM (from PV). Use epilogue_tma_store. + # For n_kv_tiles>1: O is in sO_acc (SMEM). Write to sC, then TMA store. # ============================================================ - # TMA store via CUTLASS epilogue_tma_store (reads raw O from TMEM) - tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) - c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) - c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp) - acc_cons_st = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.num_acc_stage - ) - acc_cons_st = utils.gemm.sm100.epilogue_tma_store( - self, sfw_idx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, - 0, const_expr(lambda x: x), (0, 0, 0), - acc_cons_st, acc_pipe, c_pipe, - ) - c_pipe.producer_tail() + if const_expr(not self.use_smem_accumulator): + # Path 1: epilogue_tma_store (reads O from TMEM) + tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) + c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) + c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp) + acc_cons_st = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + acc_cons_st = utils.gemm.sm100.epilogue_tma_store( + self, sfw_idx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, + 0, const_expr(lambda x: x), (0, 0, 0), + acc_cons_st, acc_pipe, c_pipe, + ) + c_pipe.producer_tail() + else: + # Path 2: sO_acc -> sC -> TMA store to GMEM + # Cast sO_acc (FP32) -> sC (BF16) using sC's layout indexing. + # sC layout from make_smem_layout_epi: ((M,N), ?, num_c_stage, ...). + # Write via sC's native coordinate system. + for row in cutlass.range(0, 128, unroll=1): + for col in cutlass.range(0, self.pv_n_tile, unroll=1): + sC[(row, col), Int32(0), Int32(0)] = sO_acc[row, col].to(self.o_dtype) + + # TMA store sC -> GMEM using cpasync.tma_partition + cute.arch.fence_proxy("async.shared", space="cta") + epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=self.epilog_sync_bar_id, + num_threads=32 * len(self.epilogue_warp_id), + ) + epilog_sync_barrier.arrive_and_wait() + tCgC_xfm = transform_partitioned_tensor_layout(tCgC) + tCgC_epi = cute.flat_divide(tCgC_xfm, epi_tile) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_c, 0, cute.make_layout(1), + cute.group_modes(sC, 0, 2), + cute.group_modes(tCgC_epi, 0, 2), + ) + c_pipe = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) + ) + c_pipe.producer_acquire() + if warp_idx == self.epilogue_warp_id[0]: + cute.copy(tma_c, bSG_sC[(Int32(0),)], bSG_gC[(Int32(0),)]) + c_pipe.producer_commit() + c_pipe.producer_tail() # Compute LSE: lse = ln(row_sum) + row_max * ln(2) # Only when emitting un-normalized output (D5a path).