diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 607072ad..71d715f5 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -370,38 +370,52 @@ class FmhaKernel: # Wait for MMA's PV[N-1] to commit before reading O. final_o_bar.arrive_and_wait() - # === Correction epilog: one-way TMEM -> reg -> SMEM -> GMEM === - # Uses epilogue_tmem_copy_and_partition (get_tmem_load_op) for correct TMEM read. - # Uses epilogue_smem_copy_and_partition (get_smem_store_op) for correct SMEM write. - # No TMEM round-trip. No layout mismatch. No 3% error. + # === Correction epilog: TMEM -> reg (normalize) -> SMEM -> GMEM === + # Full pipeline using CUTLASS epilogue helpers for correct layout handling. + # Replaces the broken NO-OP TMEM round-trip + normalize approach. inv_row_sum = Float32(1.0) / row_sum - # Set up the TMEM→reg and reg→SMEM copy atoms using CUTLASS helpers tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) - # Transform layout: ((MMA_ATOM_M, MMA_ATOM_N), MMA_M, MMA_N, STAGE) - # -> ((MMA_ATOM_M, MMA_M), (MMA_ATOM_N, MMA_N), STAGE) tCtO = transform_partitioned_tensor_layout(tCtO_base) - # Transform gC layout similarly (needed by the helpers) tCgC_xform = transform_partitioned_tensor_layout(tCgC) + + # TMEM->reg copy (uses get_tmem_load_op for correct layout) tiled_copy_t2r, tTR_tAcc, tTR_rAcc = epilogue_tmem_copy_and_partition( self, sfw_idx, tCtO, tCgC_xform, epi_tile, self.use_2cta_instrs ) + # reg->SMEM copy (uses get_smem_store_op for correct layout) tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) tiled_copy_r2s, tRS_rC, tRS_sC = epilogue_smem_copy_and_partition( self, tiled_copy_t2r, tTR_rC, sfw_idx, sC ) + # TMA SMEM->GMEM partition + tCgC_epi = cute.flat_divide(tCgC_xform, epi_tile) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_c, 0, cute.make_layout(1), + cute.group_modes(sC, 0, 2), + cute.group_modes(tCgC_epi, 0, 2), + ) # Wait for accumulator buffer - acc_pipe.consumer_wait(pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)) + acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage) + acc_pipe.consumer_wait(acc_cons_st) - # Process each subtile: TMEM load -> normalize -> BF16 convert -> SMEM store + # Process subtiles tTR_tAcc_g = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + bSG_gC_g = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) subtile_cnt = cute.size(tTR_tAcc_g.shape, mode=[3]) + epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=self.epilog_sync_bar_id, + num_threads=32 * len(self.epilogue_warp_id), + ) + c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))) + for subtile_idx in range(subtile_cnt): + # Load from TMEM tTR_tAcc_mn = tTR_tAcc_g[(None, None, None, subtile_idx)] cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) - # Normalize: O *= 1/row_sum + # NORMALIZE: O *= 1/row_sum (the key addition vs. epilogue_tma_store) for j in cutlass.range(cute.size(tTR_rAcc), vectorize=True): tTR_rAcc[j] = tTR_rAcc[j] * inv_row_sum @@ -413,28 +427,20 @@ class FmhaKernel: c_buffer = subtile_idx % self.num_c_stage cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)]) cute.arch.fence_proxy("async.shared", space="cta") + epilog_sync_barrier.arrive_and_wait() - # TMA store from SMEM to GMEM - # Partition sC and gC for TMA store (using transformed gC) - tCgC_epi = cute.flat_divide(tCgC_xform, epi_tile) - bSG_sC, bSG_gC = cpasync.tma_partition( - tma_c, 0, cute.make_layout(1), - cute.group_modes(sC, 0, 2), - cute.group_modes(tCgC_epi, 0, 2), - ) - # Only warp 0 of epilogue issues TMA store + # TMA store SMEM -> GMEM if warp_idx == self.epilogue_warp_id[0]: - cute.copy(tma_c, bSG_sC[(None, c_buffer)], bSG_gC[(None, subtile_idx)]) - # Sync after TMA store - epilog_sync_bar = pipeline.NamedBarrier( - barrier_id=self.epilog_sync_bar_id, - num_threads=32 * len(self.epilogue_warp_id), - ) - epilog_sync_bar.arrive_and_wait() + cute.copy(tma_c, bSG_sC[(None, c_buffer)], bSG_gC_g[(None, subtile_idx)]) + c_pipe.producer_commit() + c_pipe.producer_acquire() + epilog_sync_barrier.arrive_and_wait() + + epilog_sync_barrier.arrive_and_wait() # Release accumulator buffer with cute.arch.elect_one(): - acc_pipe.consumer_release(pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)) + acc_pipe.consumer_release(acc_cons_st) tmem.relinquish_alloc_permit() tmem.free(tmem_ptr)