fix: add acc_pipe pipeline for epilogue, matching 12w pattern
- Add acc_bar to the SS struct
- Create acc_pipe (full pipeline) before the if blocks
- Pass acc_pipe to epilogue_tma_store (it needs the full pipeline, not a participant)
This commit is contained in:
@@ -118,6 +118,7 @@ class FmhaV3StageC2:
|
||||
s_corr_bar: cute.struct.MemRange[cutlass.Int64, self.softmax_corr_stage * 2]
|
||||
mma_corr_bar: cute.struct.MemRange[cutlass.Int64, self.mma_corr_stage * 2]
|
||||
corr_epi_bar: cute.struct.MemRange[cutlass.Int64, self.epi_stage * 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
tmem_dealloc: cutlass.Int64; holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
@@ -132,6 +133,8 @@ class FmhaV3StageC2:
|
||||
mma_corr_prod, mma_corr_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.mma_corr_bar.data_ptr(), num_stages=self.mma_corr_stage, producer_group=cg(1), consumer_group=cg(32 * len(self.correction_warp_ids)), cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
|
||||
# Correction → Epilogue: O in SMEM ready
|
||||
corr_epi_prod, corr_epi_cons = pipeline.PipelineAsync.create(barrier_storage=st.corr_epi_bar.data_ptr(), num_stages=self.epi_stage, producer_group=cg(32 * len(self.correction_warp_ids)), consumer_group=cg(32)).make_participants()
|
||||
# Accumulator pipeline for epilogue (full pipeline, not participants)
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(), num_stages=1, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
# TMEM alloc barrier: softmax + correction + MMA
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=32 * len((*self.softmax_warp_ids, *self.correction_warp_ids, self.mma_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, allocator_warp_id=self.softmax_warp_ids[0], is_two_cta=cute.size(qk_mma.thr_id.shape) == 2, two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
@@ -412,7 +415,7 @@ class FmhaV3StageC2:
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, 1)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32)
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_pipe_cons = mma_corr_cons # reuse the MMA→correction consumer pipeline
|
||||
acc_pipe_cons = acc_pipe # full pipeline for epilogue_tma_store
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe_cons, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
epi_handle.release()
|
||||
|
||||
Reference in New Issue
Block a user