Review note: line breaks and indentation below were reconstructed from a
whitespace-collapsed copy of this patch; check the context lines against
fmha.py before applying. The `index` hashes are carried over unchanged and do
not describe the revised post-image.

diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py
index 794580d9..2402a8b8 100644
--- a/dsv4/kernels/attention/fmha.py
+++ b/dsv4/kernels/attention/fmha.py
@@ -213,20 +213,22 @@ class FmhaKernel:
         # ===== TMA LOAD warp =====
         if warp_idx == self.tma_warp_id:
             if const_expr(self.n_k_sub_tiles > 1):
-                # K sub-tiling path (hd=512): load Q and K per k_sub
+                # K sub-tiling path (hd=512): k_sub loads are unrolled at trace time,
+                # one Q/K pair per sub-tile, for any n_k_sub_tiles > 1.
                 qp.reset()
                 kvp.reset()
-                for kt in cutlass.range(0, self.n_kv_tiles, 1, unroll=1):
-                    for k_sub in cutlass.range(0, self.n_k_sub_tiles, 1, unroll=1):
-                        # Load Q[k_sub] → sQ
-                        qh = qp.acquire_and_advance()
-                        cute.copy(tma_q, tAgQ[(None, Int32(k_sub))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
-                        # Load K[k_sub] → sK
-                        kvh = kvp.acquire_and_advance()
-                        cute.copy(tma_k, tBgK[(None, Int32(k_sub))], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
-                    # Load V[kt] → sV
-                    kvh = kvp.acquire_and_advance()
-                    cute.copy(tma_v, tVgV[(None, Int32(kt))], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
+                for k_sub in cutlass.range_constexpr(self.n_k_sub_tiles):
+                    # Load Q[k_sub] → sQ
+                    qh = qp.acquire_and_advance()
+                    cute.copy(tma_q, tAgQ[(None, Int32(k_sub))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
+                    # Load K[k_sub] → sK
+                    kvh = kvp.acquire_and_advance()
+                    cute.copy(tma_k, tBgK[(None, Int32(k_sub))], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
+                # Load V[0] → sV
+                # NOTE(review): only KV tile 0 is loaded now that the kt loop is gone;
+                # confirm n_kv_tiles == 1 whenever this path is taken.
+                kvh_v = kvp.acquire_and_advance()
+                cute.copy(tma_v, tVgV[(None, Int32(0))], tVsV[(None, kvh_v.index)], tma_bar_ptr=kvh_v.barrier)
                 qp.tail()
                 kvp.tail()
             else:
@@ -246,34 +248,36 @@ class FmhaKernel:
         if warp_idx == self.mma_warp_id:
             tmem.wait_for_alloc()
             if const_expr(self.n_k_sub_tiles > 1):
-                # K sub-tiling path (hd=512): pipeline sync for k_sub iterations
-                kvh = kvc.wait_and_advance()  # initialize kvh before loops
-                for kt in cutlass.range(0, self.n_kv_tiles, 1, unroll=1):
-                    for k_sub in cutlass.range(0, self.n_k_sub_tiles, 1, unroll=1):
-                        qh = qc.wait_and_advance(); qh.release()
-                        kvh = kvc.wait_and_advance()
-                        qk_mma.set(tcgen05.Field.ACCUMULATE, k_sub != 0)
-                        for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True):
-                            cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0)
-                            qk_mma.set(tcgen05.Field.ACCUMULATE, True)
-                        kvh.release()
-                    # After all k_sub: S has full QK for this kt
-                    cute.arch.fence_view_async_tmem_store()
-                    softmax_done_bar.arrive()
-                    softmax_done_bar.arrive_and_wait()
-                    pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
-                    # Load V: consume from K/V pipeline
-                    kvh = kvc.wait_and_advance()
-                    if not self.use_smem_p:
-                        for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True):
-                            cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0)
-                            pv_mma.set(tcgen05.Field.ACCUMULATE, True)
-                    else:
-                        for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True):
-                            cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh.index)], tOtO0)
-                            pv_mma.set(tcgen05.Field.ACCUMULATE, True)
-                    cute.arch.fence_view_async_tmem_store()
-                    kvh.release()
+                # K sub-tiling path (hd=512): the QK GEMM is unrolled at trace time over
+                # the sub-tiles; every sub-tile writes into the same S tile (tStS0).
+                for k_sub in cutlass.range_constexpr(self.n_k_sub_tiles):
+                    qh = qc.wait_and_advance()
+                    qh.release()
+                    kvh = kvc.wait_and_advance()
+                    # First sub-tile overwrites S; later ones accumulate onto it.
+                    qk_mma.set(tcgen05.Field.ACCUMULATE, k_sub != 0)
+                    for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True):
+                        cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0)
+                        qk_mma.set(tcgen05.Field.ACCUMULATE, True)
+                    kvh.release()
+                # After all k_sub: S has full QK
+                cute.arch.fence_view_async_tmem_store()
+                softmax_done_bar.arrive()
+                softmax_done_bar.arrive_and_wait()
+                # Only one PV GEMM is issued on this path, so its first k-block does not accumulate.
+                pv_mma.set(tcgen05.Field.ACCUMULATE, False)
+                # Load V: consume from K/V pipeline
+                kvh_v = kvc.wait_and_advance()
+                if not self.use_smem_p:
+                    for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True):
+                        cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh_v.index)], tOtO0)
+                        pv_mma.set(tcgen05.Field.ACCUMULATE, True)
+                else:
+                    for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True):
+                        cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh_v.index)], tOtO0)
+                        pv_mma.set(tcgen05.Field.ACCUMULATE, True)
+                cute.arch.fence_view_async_tmem_store()
+                kvh_v.release()
                 final_o_bar.arrive()
             else:
                 # Original pipeline path (hd≤256)