diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index ed446a31..e11dcac0 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -130,10 +130,19 @@ class FmhaV3StageCMulti: tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape) epi_s = cute.select(self.c_smem_s,mode=[0,1]) tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile) - self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.c_smem_s,self.epi_tile).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream) + + # Pre-compute epilog atoms for correction_epilog pattern + epi_corr_tile_size = 32 * 8 // self.o_dtype.width # 16 for BF16 + epi_subtile = (self.epi_tile[0], epi_corr_tile_size) + tmem_load_epi_atom = utils.sm100.get_tmem_load_op( + self.pv_mma_tiler, self.c_layout, self.o_dtype, self.acc_dtype, + epi_subtile, use_2cta_instrs=False, + ) + + self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.c_smem_s,self.epi_tile,tmem_load_epi_atom).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream) @cute.kernel - def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile, tmem_load_epi_atom): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) tidx,_,_ = cute.arch.thread_idx() if warp_idx == self.tma_warp_id: @@ -401,44 +410,29 @@ class FmhaV3StageCMulti: # === Final O normalization + epilogue: CUTLASS correction_epilog pattern === # ONE-WAY trip: TMEM → reg (normalize + FP32→BF16) → SMEM → TMA → GMEM - # NO TMEM round-trip. Hand-constructed atoms corrupt data on round-trip. inv_row_sum = Float32(1.0) / row_sum - # Build paired atoms for TMEM load → SMEM store + # Build paired SMEM store atom from the pre-computed TMEM load atom epi_corr_tile_size = 32 * 8 // self.o_dtype.width # 16 for BF16 - epi_subtile = (self.epi_tile[0], epi_corr_tile_size) - - tOsO = pv_thr.partition_C(sC) - cO_epi = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) - tOcO_epi = pv_thr.partition_C(cO_epi) - tOtO_epi = cute.logical_divide(tOtO0, cute.make_layout((128, epi_corr_tile_size))) - tOsO_epi = cute.logical_divide(tOsO, cute.make_layout((128, epi_corr_tile_size))) - tOcO_epi = cute.logical_divide(tOcO_epi, cute.make_layout((128, epi_corr_tile_size))) - - tmem_load_epi_atom = utils.sm100.get_tmem_load_op( - self.pv_mma_tiler, - self.c_layout, - self.o_dtype, - self.acc_dtype, - epi_subtile, - use_2cta_instrs=False, - ) tiled_tmem_load_epi = tcgen05.make_tmem_copy( tmem_load_epi_atom, tOtO_epi[(None, None), 0] ) - thr_tmem_load_epi = tiled_tmem_load_epi.get_slice(sfw_idx) - smem_store_epi_atom = utils.sm100.get_smem_store_op( - self.c_layout, - self.o_dtype, - self.acc_dtype, - tiled_tmem_load_epi, + self.c_layout, self.o_dtype, self.acc_dtype, tiled_tmem_load_epi, ) tiled_smem_store_epi = cute.make_tiled_copy_D( smem_store_epi_atom, tiled_tmem_load_epi ) + # Partition SMEM for the epilog output + tOsO = pv_thr.partition_C(sC) + cO_epi = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) + tOcO_epi = pv_thr.partition_C(cO_epi) + tOsO_epi = cute.logical_divide(tOsO, cute.make_layout((128, epi_corr_tile_size))) + tOcO_epi = cute.logical_divide(tOcO_epi, cute.make_layout((128, epi_corr_tile_size))) + + thr_tmem_load_epi = tiled_tmem_load_epi.get_slice(sfw_idx) tTMEM_LOADtO_epi = thr_tmem_load_epi.partition_S(tOtO_epi[(None, None), None]) tTMEM_LOADsO_epi = thr_tmem_load_epi.partition_D(tOsO_epi[(None, None), None]) tTMEM_LOADcO_epi = thr_tmem_load_epi.partition_D(tOcO_epi[(None, None), None]) @@ -461,13 +455,11 @@ class FmhaV3StageCMulti: cute.arch.fence_proxy("async.shared", space="cta") # TMA store: SMEM → GMEM - # Sync all softmax warps before TMA store softmax_all_bar = pipeline.NamedBarrier( barrier_id=5, num_threads=32 * len(self.epilogue_warp_id) ) softmax_all_bar.arrive_and_wait() - # Use the same TMA store pattern as CUTLASS FMHA epilogue warp tCgC_epi = cute.flat_divide(tCgC, self.epi_tile) tCsC, tCgC_tma = cpasync.tma_partition( tma_c, 0, cute.make_layout(1),