fix: epilogue warp reuse mma_corr_cons pipeline instead of creating new one from st
This commit is contained in:
@@ -405,17 +405,15 @@ class FmhaV3StageC2:
|
||||
# ==================== EPILOGUE WARP (10) ====================
|
||||
if warp_idx == self.epilogue_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
tmem.allocate(self.num_tmem_alloc_cols)
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
# Wait for correction to finish normalizing O
|
||||
epi_handle = corr_epi_cons.wait_and_advance()
|
||||
# Use epilogue_tma_store to write O from TMEM to GMEM
|
||||
# Write O from TMEM to GMEM via epilogue_tma_store
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, 1)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32)
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.mma_corr_bar.data_ptr(), num_stages=self.mma_corr_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=c_grp, cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
acc_pipe_cons = mma_corr_cons # reuse the MMA→correction consumer pipeline
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe_cons, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
epi_handle.release()
|
||||
tmem.relinquish_alloc_permit()
|
||||
|
||||
Reference in New Issue
Block a user