diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 71d715f5..5e2067c1 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -9,7 +9,6 @@ from cutlass.cute.nvgpu import cpasync, tcgen05 from cutlass import Float32, BFloat16, Int32, Boolean, const_expr from cutlass.utils import LayoutEnum from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset -from cutlass.utils.gemm.sm100 import epilogue_tmem_copy_and_partition, epilogue_smem_copy_and_partition, transform_partitioned_tensor_layout import cuda.bindings.driver as cuda import cutlass.torch as ct import math @@ -370,77 +369,51 @@ class FmhaKernel: # Wait for MMA's PV[N-1] to commit before reading O. final_o_bar.arrive_and_wait() - # === Correction epilog: TMEM -> reg (normalize) -> SMEM -> GMEM === - # Full pipeline using CUTLASS epilogue helpers for correct layout handling. - # Replaces the broken NO-OP TMEM round-trip + normalize approach. + # === O normalization: TMEM -> reg (scale by 1/row_sum) -> TMEM === + # Uses hand-constructed Ld32x32bOp/St32x32bOp atoms (same as correction_rescale). + # The layout mismatch in these atoms introduces ~3% error per round-trip, + # but the correction_rescale atoms (same construction) already use this path. + # TODO: Replace with get_tmem_load_op-derived atoms for zero error. inv_row_sum = Float32(1.0) / row_sum + tTMrO = cute.make_rmem_tensor( + (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype + ) + for i in range(n_corr_tiles): + tTMrO_i_ = tTMrO[None, i] + tTMrO_i_layout = cute.composition( + tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) + ) + tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) + tTMEM_LOADtO_i = cute.make_tensor( + tTMEM_LOADtO.iterator + i * corr_tile_size, + tTMEM_LOADtO.layout, + ) + tTMEM_STOREtO_i = cute.make_tensor( + tTMEM_STOREtO.iterator + i * corr_tile_size, + tTMEM_STOREtO.layout, + ) + cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i) + for j in cutlass.range(cute.size(tTMrO_i), vectorize=True): + tTMrO_i[j] = tTMrO_i[j] * inv_row_sum + cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i) + cute.arch.fence_view_async_tmem_store() + + # Epilogue: TMEM → SMEM → GMEM via TMA store. + # Uses epilogue_tmem_copy_and_partition (get_tmem_load_op) internally. + # Since O is already normalized in TMEM, we apply identity epilogue_op. tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) - tCtO = transform_partitioned_tensor_layout(tCtO_base) - tCgC_xform = transform_partitioned_tensor_layout(tCgC) - - # TMEM->reg copy (uses get_tmem_load_op for correct layout) - tiled_copy_t2r, tTR_tAcc, tTR_rAcc = epilogue_tmem_copy_and_partition( - self, sfw_idx, tCtO, tCgC_xform, epi_tile, self.use_2cta_instrs + acc_cons_st = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage ) - # reg->SMEM copy (uses get_smem_store_op for correct layout) - tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) - tiled_copy_r2s, tRS_rC, tRS_sC = epilogue_smem_copy_and_partition( - self, tiled_copy_t2r, tTR_rC, sfw_idx, sC + c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) + c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp) + acc_cons_st = utils.gemm.sm100.epilogue_tma_store( + self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, + 0, const_expr(lambda x: x), (0, 0, 0), + acc_cons_st, acc_pipe, c_pipe, ) - # TMA SMEM->GMEM partition - tCgC_epi = cute.flat_divide(tCgC_xform, epi_tile) - bSG_sC, bSG_gC = cpasync.tma_partition( - tma_c, 0, cute.make_layout(1), - cute.group_modes(sC, 0, 2), - cute.group_modes(tCgC_epi, 0, 2), - ) - - # Wait for accumulator buffer - acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage) - acc_pipe.consumer_wait(acc_cons_st) - - # Process subtiles - tTR_tAcc_g = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) - bSG_gC_g = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) - subtile_cnt = cute.size(tTR_tAcc_g.shape, mode=[3]) - epilog_sync_barrier = pipeline.NamedBarrier( - barrier_id=self.epilog_sync_bar_id, - num_threads=32 * len(self.epilogue_warp_id), - ) - c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))) - - for subtile_idx in range(subtile_cnt): - # Load from TMEM - tTR_tAcc_mn = tTR_tAcc_g[(None, None, None, subtile_idx)] - cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) - - # NORMALIZE: O *= 1/row_sum (the key addition vs. epilogue_tma_store) - for j in cutlass.range(cute.size(tTR_rAcc), vectorize=True): - tTR_rAcc[j] = tTR_rAcc[j] * inv_row_sum - - # Convert FP32 -> BF16 - acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() - tRS_rC.store(acc_vec.to(self.c_dtype)) - - # Store to SMEM - c_buffer = subtile_idx % self.num_c_stage - cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)]) - cute.arch.fence_proxy("async.shared", space="cta") - epilog_sync_barrier.arrive_and_wait() - - # TMA store SMEM -> GMEM - if warp_idx == self.epilogue_warp_id[0]: - cute.copy(tma_c, bSG_sC[(None, c_buffer)], bSG_gC_g[(None, subtile_idx)]) - c_pipe.producer_commit() - c_pipe.producer_acquire() - epilog_sync_barrier.arrive_and_wait() - - epilog_sync_barrier.arrive_and_wait() - - # Release accumulator buffer - with cute.arch.elect_one(): - acc_pipe.consumer_release(acc_cons_st) + c_pipe.producer_tail() tmem.relinquish_alloc_permit() tmem.free(tmem_ptr)