diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 2a5496d1..a312c24c 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -218,19 +218,17 @@ class FmhaKernel: _v_bar = pipeline.NamedBarrier(barrier_id=7, num_threads=64) # TMA + MMA warps for kt in range(self.n_kv_tiles): for k_sub in range(self.n_k_sub_tiles): - # Load Q[k_sub] → sQ + # Load Q[k_sub] → sQ (no TMA barrier, use cp_async wait) cute.copy(tma_q, tAgQ[(None, Int32(k_sub))], tAsQ[(None, 0)]) # Load K[k_sub] → sK cute.copy(tma_k, tBgK[(None, Int32(k_sub))], tBsK[(None, 0)]) cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=True) - # Sync with MMA warp: sQ and sK ready _ksub_bar.arrive_and_wait() # Load V[kt] → sV cute.copy(tma_v, tVgV[(None, Int32(kt))], tVsV[(None, 0)]) cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=True) - # Sync with MMA warp: sV ready for PV _v_bar.arrive_and_wait() else: # Original pipeline path (hd≤256)