From db3572bafbc245aafc445d72a99206df882f89fc Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 01:40:13 +0000 Subject: [PATCH] fix: correction_epilog with get_tmem_load_op paired atoms, no TMEM round-trip --- tests/unit/test_fmha_v3_stage_c.py | 122 +++++++++++++++++++---------- 1 file changed, 79 insertions(+), 43 deletions(-) diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index eacbe5f7..c868ec98 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -130,10 +130,75 @@ class FmhaV3StageCMulti: tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape) epi_s = cute.select(self.c_smem_s,mode=[0,1]) tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile) - self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.c_smem_s,self.epi_tile).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream) + + # Pre-compute paired TMEM load atom for correction_epilog + epi_corr_tile_size = 32 * 8 // self.o_dtype.width # 16 for BF16 + epi_subtile = (self.epi_tile[0], epi_corr_tile_size) + tmem_load_epi_atom = utils.sm100.get_tmem_load_op( + self.pv_mma_tiler, self.c_layout, self.o_dtype, self.acc_dtype, + epi_subtile, use_2cta_instrs=False, + ) + + self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.c_smem_s,self.epi_tile,tmem_load_epi_atom).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream) + + def _correction_epilog(self, pv_thr, tOtO, scale, sC, tCgC, tma_c, sfw_idx, tidx, warp_idx, epi_tile, tmem_load_epi_atom): + """CUTLASS correction_epilog: read O from TMEM, normalize, convert, write SMEM→GMEM.""" + epi_corr_tile_size = 32 * 8 // self.o_dtype.width # 16 for BF16 + cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) + tOcO = pv_thr.partition_C(cO) + tOsO = pv_thr.partition_C(sC) + + tOtO_i = cute.logical_divide(tOtO, cute.make_layout((128, epi_corr_tile_size))) + tOcO_i = cute.logical_divide(tOcO, cute.make_layout((128, epi_corr_tile_size))) + tOsO_i = cute.logical_divide(tOsO, cute.make_layout((128, epi_corr_tile_size))) + + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_epi_atom, tOtO_i[(None, None), 0]) + smem_copy_atom = utils.sm100.get_smem_store_op( + self.c_layout, self.o_dtype, self.acc_dtype, tiled_tmem_load + ) + tiled_smem_store = cute.make_tiled_copy_D(smem_copy_atom, tiled_tmem_load) + + thr_tmem_load = tiled_tmem_load.get_slice(sfw_idx) + tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i[(None, None), None]) + tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i[(None, None), None]) + tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i[(None, None), None]) + + n_corr_tiles = self.pv_mma_tiler[1] // epi_corr_tile_size + for i in range(n_corr_tiles): + tTMrO = cute.make_rmem_tensor(tTMEM_LOADcO[None, 0, 0, i].shape, self.acc_dtype) + cute.copy(tiled_tmem_load, tTMEM_LOADtO[None, 0, 0, i], tTMrO) + for j in range(cute.size(tTMrO)): + tTMrO[j] = tTMrO[j] * scale + tSMrO = cute.make_rmem_tensor(tTMrO.shape, self.o_dtype) + o_vec = tTMrO.load() + tSMrO.store(o_vec.to(self.o_dtype)) + cute.copy(tiled_smem_store, tSMrO, tTMEM_LOADsO[None, 0, 0, i]) + + cute.arch.fence_proxy("async.shared", space="cta") + + # TMA store SMEM → GMEM (same pattern as epilogue_tma_store) + tCgC_epi = cute.flat_divide(tCgC, epi_tile) + tCsC, tCgC_tma = cpasync.tma_partition( + tma_c, 0, cute.make_layout(1), + cute.group_modes(sC, 0, 2), + cute.group_modes(tCgC_epi, 0, 2), + ) + + epilog_sync_bar = pipeline.NamedBarrier( + barrier_id=self.epilog_sync_bar_id, + num_threads=32 * len(self.epilogue_warp_id), + ) + epilog_sync_bar.arrive_and_wait() + + c_buffer = 0 + if warp_idx == self.epilogue_warp_id[0]: + cute.copy(tma_c, tCsC[(None, c_buffer)], tCgC_tma[(None, 0, 0, 0, 0, 0, 0)]) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) + epilog_sync_bar.arrive_and_wait() @cute.kernel - def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile, tmem_load_epi_atom): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) tidx,_,_ = cute.arch.thread_idx() if warp_idx == self.tma_warp_id: @@ -399,28 +464,12 @@ class FmhaV3StageCMulti: # Wait for MMA's PV[N-1] to commit before reading O. final_o_bar.arrive_and_wait() - # === Final O normalization: O *= 1/row_sum === - # DIAG: NO-OP TMEM round-trip test — load and store back without modifying + # === Correction epilog: one-way TMEM → SMEM with normalize === inv_row_sum = Float32(1.0) / row_sum - - # SKIP the TMEM round-trip normalize entirely - # Just use epilogue_tma_store to read raw PV from TMEM - # The inv_row_sum normalization will be applied in Python for now - - # Standard epilogue: TMEM → SMEM → GMEM via TMA store. - # O in TMEM is now scaled by 1/row_sum. - tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) - acc_cons_st = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.num_acc_stage + self._correction_epilog( + pv_thr, tOtO0, inv_row_sum, sC, tCgC, tma_c, sfw_idx, + tidx, warp_idx, epi_tile, tmem_load_epi_atom, ) - c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) - c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp) - acc_cons_st = utils.gemm.sm100.epilogue_tma_store( - self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, - 0, const_expr(lambda x: x), (0, 0, 0), - acc_cons_st, acc_pipe, c_pipe, - ) - c_pipe.producer_tail() tmem.relinquish_alloc_permit() tmem.free(tmem_ptr) @@ -444,10 +493,6 @@ def test(): attn = torch.softmax(attn_raw, dim=-1) ref = attn @ v.float() - # Compute unnormalized softmax and row_sum for Python-side normalize - attn_unnorm = torch.exp(attn_raw - attn_raw.max(dim=-1, keepdim=True).values) - row_sum_unnorm = attn_unnorm.sum(dim=-1, keepdim=True) - mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) @@ -465,26 +510,17 @@ def test(): torch.cuda.synchronize() out = c[:, :, 0].float() - - # Python-side normalize: out is raw P@V (unnormalized) - # Divide by row_sum to get the correct softmax output - out_normalized = out / row_sum_unnorm - cos_raw = torch.nn.functional.cosine_similarity( - out.flatten().unsqueeze(0), (attn_unnorm @ v.float()).flatten().unsqueeze(0) + cos = torch.nn.functional.cosine_similarity( + out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0) ).item() - cos_norm = torch.nn.functional.cosine_similarity( - out_normalized.flatten().unsqueeze(0), ref.flatten().unsqueeze(0) - ).item() - max_abs = (out_normalized - ref).abs().max().item() + max_abs = (out - ref).abs().max().item() n_tiles = n // 128 - print(f'FMHA Stage-C Multi n={n} ({n_tiles} kv tiles):', flush=True) - print(f' Raw PV (no normalize) vs unnorm ref: cos {cos_raw:.6f}', flush=True) - print(f' After Python normalize vs softmax ref: cos {cos_norm:.6f} max_abs {max_abs:.4f} ' - f'{"PASS" if cos_norm >= 0.99 else "FAIL"}', flush=True) - if cos_norm < 0.99: - print(f' out_normalized[0,:4]={out_normalized[0,:4].tolist()}') + print(f'FMHA Stage-C Multi n={n} ({n_tiles} kv tiles): ' + f'cos {cos:.6f} max_abs {max_abs:.4f} ' + f'{"PASS" if cos >= 0.99 else "FAIL"}') + if cos < 0.99: + print(f' out[0,:4]={out[0,:4].tolist()}') print(f' ref[0,:4]={ref[0,:4].tolist()}') - print(f' row_sum_unnorm[:4]={row_sum_unnorm[:4,0].tolist()}') if __name__ == '__main__':