From e93dabe43c5cecb0915fb9fab5a0b015a9815e1d Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 04:57:08 +0000 Subject: [PATCH] D1: K sub-tile MMA path using pipeline barriers --- dsv4/kernels/attention/fmha.py | 48 ++++++++++++++-------------------- 1 file changed, 20 insertions(+), 28 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index a312c24c..07a80476 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -213,23 +213,21 @@ class FmhaKernel: # ===== TMA LOAD warp ===== if warp_idx == self.tma_warp_id: if const_expr(self.n_k_sub_tiles > 1): - # K sub-tiling path (hd=512): serialized TMA loads with barrier sync - _ksub_bar = pipeline.NamedBarrier(barrier_id=6, num_threads=64) # TMA + MMA warps - _v_bar = pipeline.NamedBarrier(barrier_id=7, num_threads=64) # TMA + MMA warps + # K sub-tiling path (hd=512): load Q and K per k_sub using pipeline barriers for kt in range(self.n_kv_tiles): for k_sub in range(self.n_k_sub_tiles): - # Load Q[k_sub] → sQ (no TMA barrier, use cp_async wait) - cute.copy(tma_q, tAgQ[(None, Int32(k_sub))], tAsQ[(None, 0)]) + # Load Q[k_sub] → sQ + qh = qp.acquire_and_advance() + cute.copy(tma_q, tAgQ[(None, Int32(k_sub))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) # Load K[k_sub] → sK - cute.copy(tma_k, tBgK[(None, Int32(k_sub))], tBsK[(None, 0)]) - cute.arch.cp_async_bulk_commit_group() - cute.arch.cp_async_bulk_wait_group(0, read=True) - _ksub_bar.arrive_and_wait() + kvh = kvp.acquire_and_advance() + cute.copy(tma_k, tBgK[(None, Int32(k_sub))], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + qh = qp.wait_and_advance(); qh.release() + kvh = kvp.wait_and_advance(); pk = cutlass.Boolean(1) # Load V[kt] → sV - cute.copy(tma_v, tVgV[(None, Int32(kt))], tVsV[(None, 0)]) - cute.arch.cp_async_bulk_commit_group() - cute.arch.cp_async_bulk_wait_group(0, read=True) - _v_bar.arrive_and_wait() + # (V doesn't depend on k_sub, load once per kt) + # V is already loaded in the K/V pipeline + kvp.tail() else: # Original pipeline path (hd≤256) qp.reset(); qh = qp.acquire_and_advance() @@ -247,36 +245,30 @@ class FmhaKernel: if warp_idx == self.mma_warp_id: tmem.wait_for_alloc() if const_expr(self.n_k_sub_tiles > 1): - # K sub-tiling path (hd=512): serialized with barrier sync - _ksub_bar = pipeline.NamedBarrier(barrier_id=6, num_threads=64) - _v_bar = pipeline.NamedBarrier(barrier_id=7, num_threads=64) + # K sub-tiling path (hd=512): pipeline sync for k_sub iterations for kt in range(self.n_kv_tiles): for k_sub in range(self.n_k_sub_tiles): - # Wait for TMA warp: sQ and sK ready - _ksub_bar.arrive_and_wait() - # QK GEMM: load Q from sQ, K from sK, accumulate into S + qh = qc.wait_and_advance(); qh.release() + kvh = kvc.wait_and_advance() qk_mma.set(tcgen05.Field.ACCUMULATE, k_sub != 0) for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True): - cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,0)], tStS0) + cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0) qk_mma.set(tcgen05.Field.ACCUMULATE, True) # After all k_sub: S has full QK for this kt cute.arch.fence_view_async_tmem_store() - # Signal softmax - softmax_done_bar.arrive() # signal that S is ready - # Wait for V load - _v_bar.arrive_and_wait() - # PV GEMM + softmax_done_bar.arrive() + softmax_done_bar.arrive_and_wait() pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0) if not self.use_smem_p: for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True): - cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,0)], tOtO0) + cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0) pv_mma.set(tcgen05.Field.ACCUMULATE, True) else: for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True): - cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,0)], tOtO0) + cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh.index)], tOtO0) pv_mma.set(tcgen05.Field.ACCUMULATE, True) cute.arch.fence_view_async_tmem_store() - softmax_done_bar.arrive() # signal PV complete for this kt + kvh.release() final_o_bar.arrive() else: # Original pipeline path (hd≤256)