diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index e8afb4e5..2a5496d1 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -181,7 +181,6 @@ class FmhaKernel: tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)] - print(f"TMA: tAgQ shape={cute.shape(tAgQ)}, tBgK shape={cute.shape(tBgK)}, tVgV shape={cute.shape(tVgV)}") tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) tCrV = pv_mma.make_fragment_B(sV) @@ -213,50 +212,104 @@ class FmhaKernel: # ===== TMA LOAD warp ===== if warp_idx == self.tma_warp_id: - qp.reset(); qh = qp.acquire_and_advance() - cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) - qp.tail() - kvp.reset(); pk = kvp.try_acquire() - for kt in cutlass.range(0, self.n_kv_tiles, 1, unroll=1): - kvh = kvp.acquire_and_advance(pk) - cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - pk = cutlass.Boolean(1) - kvp.tail() + if const_expr(self.n_k_sub_tiles > 1): + # K sub-tiling path (hd=512): serialized TMA loads with barrier sync + _ksub_bar = pipeline.NamedBarrier(barrier_id=6, num_threads=64) # TMA + MMA warps + _v_bar = pipeline.NamedBarrier(barrier_id=7, num_threads=64) # TMA + MMA warps + for kt in range(self.n_kv_tiles): + for k_sub in range(self.n_k_sub_tiles): + # Load Q[k_sub] → sQ + cute.copy(tma_q, tAgQ[(None, Int32(k_sub))], tAsQ[(None, 0)]) + # Load K[k_sub] → sK + cute.copy(tma_k, tBgK[(None, Int32(k_sub))], tBsK[(None, 0)]) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) + # Sync with MMA warp: sQ and sK ready + _ksub_bar.arrive_and_wait() + # Load V[kt] → sV + cute.copy(tma_v, tVgV[(None, Int32(kt))], tVsV[(None, 0)]) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) + # Sync with MMA warp: sV ready for PV + _v_bar.arrive_and_wait() + else: + # Original pipeline path (hd≤256) + qp.reset(); qh = qp.acquire_and_advance() + cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) + qp.tail() + kvp.reset(); pk = kvp.try_acquire() + for kt in cutlass.range(0, self.n_kv_tiles, 1, unroll=1): + kvh = kvp.acquire_and_advance(pk) + cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + pk = cutlass.Boolean(1) + kvp.tail() # ===== MMA warp ===== if warp_idx == self.mma_warp_id: tmem.wait_for_alloc() - qc.reset(); qh = qc.wait_and_advance(); qh.release() - kvc.reset(); pk = kvc.try_wait() - acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage) - acc_pipe.producer_acquire(acc_st) - for kt in range(self.n_kv_tiles): - kvh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1) - sh = s_prod.acquire_and_advance() - qk_mma.set(tcgen05.Field.ACCUMULATE, False) - for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True): - cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0) - qk_mma.set(tcgen05.Field.ACCUMULATE, True) - cute.arch.fence_view_async_tmem_store() - sh.commit() - softmax_done_bar.arrive_and_wait() - pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0) - if not self.use_smem_p: - # TMEM-P: PV reads P from TMEM - for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True): - cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0) - pv_mma.set(tcgen05.Field.ACCUMULATE, True) - else: - # SMEM-P: PV reads P from SMEM - for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True): - cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh.index)], tOtO0) - pv_mma.set(tcgen05.Field.ACCUMULATE, True) - cute.arch.fence_view_async_tmem_store() - kvh.release() - acc_pipe.producer_commit(acc_st); acc_st.advance() - final_o_bar.arrive() - acc_pipe.producer_tail(acc_st) + if const_expr(self.n_k_sub_tiles > 1): + # K sub-tiling path (hd=512): serialized with barrier sync + _ksub_bar = pipeline.NamedBarrier(barrier_id=6, num_threads=64) + _v_bar = pipeline.NamedBarrier(barrier_id=7, num_threads=64) + for kt in range(self.n_kv_tiles): + for k_sub in range(self.n_k_sub_tiles): + # Wait for TMA warp: sQ and sK ready + _ksub_bar.arrive_and_wait() + # QK GEMM: load Q from sQ, K from sK, accumulate into S + qk_mma.set(tcgen05.Field.ACCUMULATE, k_sub != 0) + for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True): + cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,0)], tStS0) + qk_mma.set(tcgen05.Field.ACCUMULATE, True) + # After all k_sub: S has full QK for this kt + cute.arch.fence_view_async_tmem_store() + # Signal softmax + softmax_done_bar.arrive() # signal that S is ready + # Wait for V load + _v_bar.arrive_and_wait() + # PV GEMM + pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0) + if not self.use_smem_p: + for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True): + cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,0)], tOtO0) + pv_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True): + cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,0)], tOtO0) + pv_mma.set(tcgen05.Field.ACCUMULATE, True) + cute.arch.fence_view_async_tmem_store() + softmax_done_bar.arrive() # signal PV complete for this kt + final_o_bar.arrive() + else: + # Original pipeline path (hd≤256) + qc.reset(); qh = qc.wait_and_advance(); qh.release() + kvc.reset(); pk = kvc.try_wait() + acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage) + acc_pipe.producer_acquire(acc_st) + for kt in range(self.n_kv_tiles): + kvh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1) + sh = s_prod.acquire_and_advance() + qk_mma.set(tcgen05.Field.ACCUMULATE, False) + for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True): + cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0) + qk_mma.set(tcgen05.Field.ACCUMULATE, True) + cute.arch.fence_view_async_tmem_store() + sh.commit() + softmax_done_bar.arrive_and_wait() + pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0) + if not self.use_smem_p: + for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True): + cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0) + pv_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True): + cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh.index)], tOtO0) + pv_mma.set(tcgen05.Field.ACCUMULATE, True) + cute.arch.fence_view_async_tmem_store() + kvh.release() + acc_pipe.producer_commit(acc_st); acc_st.advance() + final_o_bar.arrive() + acc_pipe.producer_tail(acc_st) # ===== SOFTMAX + CORRECTION EPILOGUE warps ===== if warp_idx < self.mma_warp_id: