fix: epilogue warp reuse mma_corr_cons pipeline instead of creating new one from st

This commit is contained in:
2026-05-22 09:56:18 +00:00
parent b149c310ac
commit ad8865eb73

View File

@@ -405,17 +405,15 @@ class FmhaV3StageC2:
# ==================== EPILOGUE WARP (10) ====================
if warp_idx == self.epilogue_warp_id:
tmem.wait_for_alloc()
tmem.allocate(self.num_tmem_alloc_cols)
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
# Wait for correction to finish normalizing O
epi_handle = corr_epi_cons.wait_and_advance()
# Use epilogue_tma_store to write O from TMEM to GMEM
# Write O from TMEM to GMEM via epilogue_tma_store
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, 1)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32)
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.mma_corr_bar.data_ptr(), num_stages=self.mma_corr_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=c_grp, cta_layout_vmnk=cl_vmnk, defer_sync=True)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
acc_pipe_cons = mma_corr_cons # reuse the MMA→correction consumer pipeline
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe_cons, c_pipe)
c_pipe.producer_tail()
epi_handle.release()
tmem.relinquish_alloc_permit()