From ad8865eb731359ace472bce8d46cd1f2a9932db1 Mon Sep 17 00:00:00 2001
From: biondizzle
Date: Fri, 22 May 2026 09:56:18 +0000
Subject: [PATCH] fix: epilogue warp reuse mma_corr_cons pipeline instead of creating new one from st

---
NOTE(review): line breaks were restored from a whitespace-collapsed copy of
this patch. Leading indentation inside the hunk is reconstructed and may not
match the target file; apply with `git apply --ignore-whitespace` or verify
against the source tree.

 tests/unit/test_fmha_v3_stage_c2.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/tests/unit/test_fmha_v3_stage_c2.py b/tests/unit/test_fmha_v3_stage_c2.py
index 9ea606dc..d79ef7e3 100644
--- a/tests/unit/test_fmha_v3_stage_c2.py
+++ b/tests/unit/test_fmha_v3_stage_c2.py
@@ -405,17 +405,15 @@ class FmhaV3StageC2:
         # ==================== EPILOGUE WARP (10) ====================
         if warp_idx == self.epilogue_warp_id:
             tmem.wait_for_alloc()
-            tmem.allocate(self.num_tmem_alloc_cols)
             tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
-            # Wait for correction to finish normalizing O
             epi_handle = corr_epi_cons.wait_and_advance()
-            # Use epilogue_tma_store to write O from TMEM to GMEM
+            # Write O from TMEM to GMEM via epilogue_tma_store
             tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
             acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, 1)
             c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32)
             c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
-            acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.mma_corr_bar.data_ptr(), num_stages=self.mma_corr_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=c_grp, cta_layout_vmnk=cl_vmnk, defer_sync=True)
-            acc_cons_st = utils.gemm.sm100.epilogue_tma_store(self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
+            acc_pipe_cons = mma_corr_cons  # reuse the MMA→correction consumer pipeline
+            acc_cons_st = utils.gemm.sm100.epilogue_tma_store(self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe_cons, c_pipe)
             c_pipe.producer_tail()
             epi_handle.release()
             tmem.relinquish_alloc_permit()