From 65e52f5934b7bc1f355f8d767fac325d7be31296 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 10:15:26 +0000 Subject: [PATCH] fix: add softmax_done_bar to synchronize MMA PV with softmax P production MMA must wait for softmax to produce P in TMEM before starting PV. Without this, MMA reads stale P data from TMEM, causing deadlock. softmax_done_bar: softmax warps arrive after P store, MMA waits before PV. --- tests/unit/test_fmha_v3_stage_c2.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_fmha_v3_stage_c2.py b/tests/unit/test_fmha_v3_stage_c2.py index 6868be2a..348db5d3 100644 --- a/tests/unit/test_fmha_v3_stage_c2.py +++ b/tests/unit/test_fmha_v3_stage_c2.py @@ -137,6 +137,8 @@ class FmhaV3StageC2: acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(), num_stages=1, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True) # TMEM alloc barrier: softmax + correction + MMA tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=32 * len((*self.softmax_warp_ids, *self.correction_warp_ids, self.mma_warp_id))) + # Softmax done barrier: MMA waits for softmax to produce P before starting PV + softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 * len(self.softmax_warp_ids) + 32) tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, allocator_warp_id=self.softmax_warp_ids[0], is_two_cta=cute.size(qk_mma.thr_id.shape) == 2, two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) if warp_idx == self.empty_warp_id: cute.arch.mbarrier_init(st.tmem_dealloc, 32 * len((*self.softmax_warp_ids, *self.correction_warp_ids))) @@ -208,7 +210,8 @@ class FmhaV3StageC2: cute.gemm(qk_mma, tStS0, tCrQ[(None, None, kb, 0)], tCrK[(None, None, kb, kh.index)], tStS0) qk_mma.set(tcgen05.Field.ACCUMULATE, True) cute.arch.fence_view_async_tmem_store(); sh.commit(); kh.release() - # PV -> O (softmax consumes S and produces P between these two) + # PV -> O (wait for softmax to produce P) + softmax_done_bar.arrive_and_wait() vh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1) oh = mma_corr_prod.acquire_and_advance() pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0) @@ -308,6 +311,7 @@ class FmhaV3StageC2: cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() si_handle.release() + softmax_done_bar.arrive() vec_handle = s_corr_prod.acquire_and_advance() # Final vec = [row_sum, row_max]