fix: add epilogue warp to tmem_bar, restore wait_for_alloc in epilogue
The epilogue needs tmem_ptr for epilogue_tma_store. It must be part of the tmem alloc barrier to synchronize.
This commit is contained in:
@@ -135,8 +135,8 @@ class FmhaV3StageC2:
|
||||
corr_epi_prod, corr_epi_cons = pipeline.PipelineAsync.create(barrier_storage=st.corr_epi_bar.data_ptr(), num_stages=self.epi_stage, producer_group=cg(32 * len(self.correction_warp_ids)), consumer_group=cg(32)).make_participants()
|
||||
# Accumulator pipeline for epilogue (full pipeline, not participants)
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(), num_stages=1, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
# TMEM alloc barrier: softmax + correction + MMA
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=32 * len((*self.softmax_warp_ids, *self.correction_warp_ids, self.mma_warp_id)))
|
||||
# TMEM alloc barrier: softmax + correction + MMA + epilogue
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=32 * len((*self.softmax_warp_ids, *self.correction_warp_ids, self.mma_warp_id, self.epilogue_warp_id)))
|
||||
# Softmax done barrier: MMA waits for softmax to produce P before starting PV
|
||||
softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 * len(self.softmax_warp_ids) + 32)
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, allocator_warp_id=self.softmax_warp_ids[0], is_two_cta=cute.size(qk_mma.thr_id.shape) == 2, two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
Reference in New Issue
Block a user