diff --git a/dsv4/kernels/attention/fmha_smem_acc.py b/dsv4/kernels/attention/fmha_smem_acc.py index a899d98d..258c82db 100644 --- a/dsv4/kernels/attention/fmha_smem_acc.py +++ b/dsv4/kernels/attention/fmha_smem_acc.py @@ -1,25 +1,13 @@ -"""FMHA kernel: QK -> online softmax -> PV (CuTeDSL, Blackwell SM100). +"""FMHA kernel: SMEM accumulator approach for multi-KV-tile O rescale. -SMEM accumulator approach for multi-KV-tile O rescale. -Instead of TMEM round-trip (which corrupts data), we move O from TMEM -to SMEM after each PV GEMM via one-way epilogue, and accumulate in SMEM. - -This avoids the D1.5 TMEM round-trip bug entirely. +TMEM round-trip is FUNDAMENTALLY BROKEN (Ld32x32bOp/St32x32bOp column +mapping mismatch, even NO-OP corrupts). This kernel avoids it entirely. Architecture: -- 6-warp specialization: 4 softmax+epilogue, 1 MMA, 1 TMA -- After PV[kt]: one-way TMEM→REGS→SMEM with acc_scale multiplication -- SMEM accumulator persists across kt iterations -- Final TMA store: SMEM→GMEM - -Per-kt flow: -1. Softmax warps: compute P[kt], acc_scale[kt] -2. Signal softmax_done_bar -3. MMA warp: PV[kt] GEMM (ACCUMULATE=False, fresh TMEM) -4. Signal pv_done_bar -5. Softmax/epilogue warps: TMEM→REGS, acc_scale*O_acc + O_kt, REGS→SMEM -6. Repeat for next kt -7. After all kt: SMEM→GMEM via TMA +- 6-warp: 4 softmax+epilogue, 1 MMA, 1 TMA +- PV always ACCUMULATE=False (fresh TMEM each kt) +- After pv_done_bar: one-way TMEM->REGS load O_kt, accumulate in SMEM +- Final: normalize sO_acc -> sC (BF16) -> TMA store to GMEM """ import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline from cutlass.cute.nvgpu import cpasync, tcgen05 @@ -59,8 +47,17 @@ class FmhaKernel: self.acc_dtype = Float32; self.qk_acc_dtype = Float32 self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1 + self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0,1,2,3); self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.k_tile = min(head_dim, 256) + self.n_k_sub_tiles = head_dim // self.k_tile + self.kv_stage = 1 if head_dim > 128 else 2 + self.q_stage = 1 + self.num_c_stage = 1 if head_dim > 256 else 2 self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim) self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e) + self.smem_acc_dtype = Float32 def _setup(self, qk_mma, pv_mma): qk_ik = cute.size(qk_mma.shape_mnk, mode=[2]) @@ -76,7 +73,7 @@ class FmhaKernel: self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage) self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage) self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage) - self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, self.num_c_stage) self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) self.p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) qk_thr = qk_mma.get_slice(0); qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) @@ -124,10 +121,11 @@ class FmhaKernel: ), ) self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode() - qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, tcgen05.CtaGroup.ONE, (128,128), tcgen05.OperandSource.SMEM) + self.c_layout = LayoutEnum.from_tensor(c) + qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM) pv_a_major = self.a_major if self.use_smem_p else cute.nvgpu.OperandMajorMode.K pv_source = tcgen05.OperandSource.SMEM if self.use_smem_p else tcgen05.OperandSource.TMEM - pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, pv_a_major, self.v_major, self.qk_acc_dtype, tcgen05.CtaGroup.ONE, (128,self.pv_n_tile), pv_source) + pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, pv_a_major, self.v_major, self.qk_acc_dtype, self.cta_group, (128,self.pv_n_tile), pv_source) self._setup(qk_mma, pv_mma) q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0)) tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape) @@ -148,4 +146,422 @@ class FmhaKernel: self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse,swa_len,sink_bias,row_sums).launch(grid=(1,1,self.batch_size),block=[self.threads_per_cta,1,1],stream=stream) - # ... rest of kernel to be implemented + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE, swa_len, mSinkBias, mRowSums): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx,_,_ = cute.arch.thread_idx() + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + q_bar: cute.struct.MemRange[cutlass.Int64, self.q_stage*2] + kv_bar: cute.struct.MemRange[cutlass.Int64, self.kv_stage*2] + s_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage*2] + tmem_dealloc: cutlass.Int64; holding: cutlass.Int32 + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants() + kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants() + s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants() + softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id)) + final_o_bar = pipeline.NamedBarrier(barrier_id=4, num_threads=32 + 32*len(self.epilogue_warp_id)) + pv_done_bar = pipeline.NamedBarrier(barrier_id=5, num_threads=32 + 32*len(self.epilogue_warp_id)) + acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True) + tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk,is_relaxed=True) + + sQ = smem.allocate_tensor(element_type=self.q_dtype,layout=q_smem_s.outer,byte_alignment=128,swizzle=q_smem_s.inner) + sK = smem.allocate_tensor(element_type=self.q_dtype,layout=k_smem_s.outer,byte_alignment=128,swizzle=k_smem_s.inner) + sV = smem.allocate_tensor(element_type=self.q_dtype,layout=v_smem_s.outer,byte_alignment=128,swizzle=v_smem_s.inner) + sC = smem.allocate_tensor(element_type=self.o_dtype,layout=c_smem_s.outer,byte_alignment=128,swizzle=c_smem_s.inner) + if const_expr(self.use_smem_p): + _p_layout = p_smem_s.outer + _p_swizzle = p_smem_s.inner + else: + _p_layout = cute.make_layout(((1,1),1,(1,1),1)) + _p_swizzle = cute.make_layout(((1,1),1,(1,1),1)) + sP = smem.allocate_tensor(element_type=self.q_dtype,layout=_p_layout,byte_alignment=128,swizzle=_p_swizzle) + + # SMEM accumulator: FP32 [128, pv_n_tile] row-major + sO_acc_layout = cute.make_layout((128, self.pv_n_tile), stride=(self.pv_n_tile, 1)) + sO_acc = smem.allocate_tensor(element_type=self.smem_acc_dtype, layout=sO_acc_layout, byte_alignment=128) + + # Zero-initialize sO_acc + if warp_idx < self.mma_warp_id: + for i in cutlass.range(0, cute.size(sO_acc), unroll=1): + sO_acc[i] = Float32(0.0) + + gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None)) + gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None)) + gV = cute.local_tile(mV,cute.slice_(self.pv_mma_tiler,(0,None,None)),(None,None,None)) + gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None)) + n_kv_tiles = cute.size(gK, mode=[3]) + + qk_thr = qk_mma.get_slice(0); pv_thr = pv_mma.get_slice(0) + tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK) + tCgV = pv_thr.partition_B(gV); tCgC = pv_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,0,None,0)).shape) + tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(sQ,0,3),cute.group_modes(tCgQ,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) + tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) + tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) + tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)] + + tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) + tCrV = pv_mma.make_fragment_B(sV) + + qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_as) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_as) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + tOrP_base = pv_thr.make_fragment_A(tP if not self.use_smem_p else sP) + tOrP = tOrP_base[(None,None,None,0)] + tCrP = pv_mma.make_fragment_A(sP) if self.use_smem_p else pv_mma.make_fragment_A(tP) + if const_expr(self.tOrP0_offset > 0): + tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout) + else: + tOrP0 = tOrP + + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage)) + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # ===== TMA LOAD warp ===== + if warp_idx == self.tma_warp_id: + if const_expr(self.n_k_sub_tiles > 1): + qp.reset(); kvp.reset() + for k_sub in cutlass.range(0, self.n_k_sub_tiles, 1, unroll=1): + qh = qp.acquire_and_advance() + cute.copy(tma_q, tAgQ[(None, k_sub)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) + kvh = kvp.acquire_and_advance() + cute.copy(tma_k, tBgK[(None, k_sub)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + kvh_v = kvp.acquire_and_advance() + cute.copy(tma_v, tVgV[(None, Int32(0))], tVsV[(None, kvh_v.index)], tma_bar_ptr=kvh_v.barrier) + qp.tail(); kvp.tail() + else: + qp.reset(); qh = qp.acquire_and_advance() + cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) + qp.tail() + kvp.reset(); pk = kvp.try_acquire() + for kt in cutlass.range(0, self.n_kv_tiles, 1, unroll=1): + kvh = kvp.acquire_and_advance(pk) + cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + pk = cutlass.Boolean(1) + kvp.tail() + + # ===== MMA warp ===== + if warp_idx == self.mma_warp_id: + tmem.wait_for_alloc() + if const_expr(self.n_k_sub_tiles > 1): + qc.reset(); kvc.reset() + qk_mma.set(tcgen05.Field.ACCUMULATE, False) + for k_sub in cutlass.range(0, self.n_k_sub_tiles, 1, unroll=1): + qh = qc.wait_and_advance(); qh.release() + kvh = kvc.wait_and_advance() + for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True): + cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0) + qk_mma.set(tcgen05.Field.ACCUMULATE, True) + kvh.release() + cute.arch.fence_view_async_tmem_store() + softmax_done_bar.arrive() + softmax_done_bar.arrive_and_wait() + # PV: ACCUMULATE=False for SMEM accumulator + pv_mma.set(tcgen05.Field.ACCUMULATE, False) + kvh_v = kvc.wait_and_advance() + if not self.use_smem_p: + for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True): + cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh_v.index)], tOtO0) + pv_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True): + cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh_v.index)], tOtO0) + pv_mma.set(tcgen05.Field.ACCUMULATE, True) + cute.arch.fence_view_async_tmem_store() + kvh_v.release() + pv_done_bar.arrive() + final_o_bar.arrive() + else: + qc.reset(); qh = qc.wait_and_advance(); qh.release() + kvc.reset(); pk = kvc.try_wait() + acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage) + acc_pipe.producer_acquire(acc_st) + for kt in range(self.n_kv_tiles): + kvh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1) + sh = s_prod.acquire_and_advance() + qk_mma.set(tcgen05.Field.ACCUMULATE, False) + for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True): + cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0) + qk_mma.set(tcgen05.Field.ACCUMULATE, True) + cute.arch.fence_view_async_tmem_store() + sh.commit() + softmax_done_bar.arrive_and_wait() + # PV: ACCUMULATE=False for SMEM accumulator + pv_mma.set(tcgen05.Field.ACCUMULATE, False) + if not self.use_smem_p: + for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True): + cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0) + pv_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True): + cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh.index)], tOtO0) + pv_mma.set(tcgen05.Field.ACCUMULATE, True) + cute.arch.fence_view_async_tmem_store() + kvh.release() + pv_done_bar.arrive() + acc_pipe.producer_commit(acc_st); acc_st.advance() + final_o_bar.arrive() + acc_pipe.producer_tail(acc_st) + + # ===== SOFTMAX + SMEM ACCUMULATOR EPILOGUE warps ===== + if warp_idx < self.mma_warp_id: + tmem.allocate(self.num_tmem_alloc_cols) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + sfw_idx = tidx % (32 * len(self.epilogue_warp_id)) + + # S load atoms + tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0) + thr_load = tiled_tmem_load.get_slice(sfw_idx) + tTMEM_LOADtS = thr_load.partition_S(tStS0) + cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS) + tTMEM_LOADcS = thr_load.partition_D(tScS) + + # O load atoms: one-way TMEM->REGS read of O after PV + # Uses same Ld32x32bOp pattern from O's TMEM offset (tOtO0) + o_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_o_load = tcgen05.make_tmem_copy(o_load_atom, tOtO0) + thr_o_load = tiled_o_load.get_slice(sfw_idx) + tTMEM_LOADtO = thr_o_load.partition_S(tOtO0) + # Coordinate tensor for O: maps register positions to (row, col) + cO = cute.make_identity_tensor((self.qk_mma_tiler[0], self.pv_mma_tiler[1])) + tOcO = pv_thr.partition_C(cO) + tTMEM_LOADcO = thr_o_load.partition_D(tOcO) + + # P store atoms (TMEM-P path) + p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width + tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32))) + tStP0 = cute.make_tensor(tStS.iterator + max(self.tmem_p0_offset, 0), tStP_layout) + tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0) + thr_store = tiled_tmem_store.get_slice(sfw_idx) + tTMEM_STOREtP = thr_store.partition_D(tStP0) + tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32))) + tScP = cute.make_tensor(tScS.iterator, tScP_layout) + tTMEM_STOREcP = thr_store.partition_S(tScP) + + _sP_nostage = sP[(None, None, None, 0)] + + row_max = -Float32.inf + row_sum = Float32(0.0) + scale_log2 = Float32(self.scale_softmax_log2) + + # ============================================================ + # MAIN LOOP: softmax + SMEM accumulator + # ============================================================ + for kt in range(self.n_kv_tiles): + si_handle = s_cons.wait_and_advance() + + # --- Load S from TMEM --- + tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype) + cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS) + cute.arch.fence_view_async_tmem_load() + + # D3/D4/D5c: logit modification + if const_expr(self.apply_swa_mask or self.is_causal or self.apply_sink_bias): + kt_offset = Int32(kt * 128) + sink_val = Float32(0.0) + if const_expr(self.apply_sink_bias): + sink_val = mSinkBias[Int32(0)] / Float32(self.scale_softmax) + for j0 in range(32): + for j1 in range(4): + coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0] + m_coord = coord[0]; k_coord = coord[1] + kv_pos = kt_offset + k_coord + if const_expr(self.apply_sink_bias): + if kv_pos >= Int32(self.n_comp): + tTMEM_LOADrS[(j0, 0), j1, 0, 0] = tTMEM_LOADrS[(j0, 0), j1, 0, 0] + sink_val + should_mask = Boolean(0) + if const_expr(self.apply_swa_mask): + if kv_pos >= Int32(self.n_comp) + swa_len: + should_mask = Boolean(1) + if const_expr(self.is_causal): + if kv_pos >= Int32(self.n_comp): + swa_pos = kv_pos - Int32(self.n_comp) + if swa_pos > m_coord: + should_mask = Boolean(1) + if should_mask: + tTMEM_LOADrS[(j0, 0), j1, 0, 0] = -Float32.inf + + # --- Online softmax --- + old_row_max = row_max + frg_cnt = 4 + frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt + tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) + for j in range(frg_cnt): + for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): + row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2) + + row_max_safe = row_max + if row_max == -cutlass.Float32.inf: + row_max_safe = Float32(0.0) + + acc_scale_ = old_row_max - row_max_safe + acc_scale = cute.math.exp2(acc_scale_, fastmath=True) + if old_row_max == -cutlass.Float32.inf: + acc_scale = Float32(0.0) + row_sum *= acc_scale + + # --- Compute P = softmax(S) --- + rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype) + rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout) + minus_row_max = Float32(0.0) - row_max_safe + + rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile)) + for j in range(frg_cnt): + for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): + tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max + tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True) + row_sum = row_sum + tTMEM_LOADrS_frg[k, j] + s_vec = tTMEM_LOADrS_frg[None, j].load() + rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype)) + + # --- Store P to TMEM or SMEM --- + if not self.use_smem_p: + cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) + cute.arch.fence_view_async_tmem_store() + else: + for j0 in range(32): + for j1 in range(4): + coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0] + m_coord = coord[0]; k_coord = coord[1] + k0 = k_coord % 16 + k1 = (k_coord // 16) % 4 + k2 = k_coord // 64 + _sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0] + cute.arch.fence_proxy("async.shared", space="cta") + + si_handle.release() + softmax_done_bar.arrive() + + # --- Wait for PV[kt] to complete --- + pv_done_bar.arrive_and_wait() + + # ======================================================== + # SMEM ACCUMULATOR: load O_kt from TMEM, accumulate in SMEM + # ======================================================== + # O_kt is in TMEM (PV with ACCUMULATE=False → fresh output) + # Load via one-way Ld32x32bOp (read-only, NO write-back to TMEM) + # Then: sO_acc = acc_scale * sO_acc + O_kt + # Using coordinate-indexed writes to sO_acc + # ======================================================== + rO = cute.make_rmem_tensor(tTMEM_LOADcO.shape, self.qk_acc_dtype) + cute.copy(tiled_o_load, tTMEM_LOADtO, rO) + cute.arch.fence_view_async_tmem_load() + + # Rescale existing sO_acc and add O_kt + # Use coordinate tensor to map each register to (row, col) in sO_acc + for j0 in range(32): + for j1 in range(4): + coord = tTMEM_LOADcO[(j0, 0), j1, 0, 0] + row = coord[0] + col = coord[1] + old_val = sO_acc[row, col] + new_val = acc_scale * old_val + rO[(j0, 0), j1, 0, 0] + sO_acc[row, col] = new_val + + # Wait for MMA's final signal + final_o_bar.arrive_and_wait() + + # ============================================================ + # EPILOGUE: normalize sO_acc, cast to BF16, TMA store to GMEM + # ============================================================ + # sO_acc has the un-normalized O accumulated across all kt. + # Normalize: O_norm = O_unnorm / row_sum + # Then cast to BF16 and write to sC for TMA store. + # ============================================================ + + # Compute LSE and row_sum output + if const_expr(not self.normalize): + _row_max_safe = row_max + if row_max == -cutlass.Float32.inf: + _row_max_safe = Float32(0.0) + _ln2 = Float32(0.6931471805599453) + lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2 + mLSE[sfw_idx, Int32(0), Int32(0)] = lse_val + mRowSums[sfw_idx, Int32(0), Int32(0)] = row_sum + + # Normalize and cast sO_acc -> sC + # Each thread handles its rows (sfw_idx maps to rows in sO_acc) + # sO_acc is (128, pv_n_tile), sC layout may differ + # Use coordinate-based write to sC via epi_tile + # + # For TMA store, we need data in sC in the layout expected by tma_c. + # We can't easily do coordinate-indexed writes to sC (swizzled layout). + # Instead: normalize in sO_acc, then bulk-copy to sC via SMEM copy. + # + # Simpler approach for n_kv_tiles=1 compatibility: + # For n_kv_tiles=1, we can use the existing epilogue_tma_store path. + # For n_kv_tiles>1, we use the sO_acc -> sC -> TMA path. + # + # For now: normalize sO_acc in-place, then copy to sC (BF16), then TMA store. + if const_expr(self.normalize): + # Normalize: divide by row_sum + # Each of the 128 softmax threads handles one row + inv_row_sum = Float32(1.0) / row_sum + for col in cutlass.range(0, self.pv_n_tile, unroll=1): + row = sfw_idx + if row < Int32(128): + sO_acc[row, col] = sO_acc[row, col] * inv_row_sum + + # Copy sO_acc (FP32) -> sC (BF16) using SMEM copy + # sC has swizzled layout from compute_epilogue_tile_shape, + # but we can write to it using the epi_tile coordinate mapping. + # + # Alternative: use TMA store directly from a properly laid out SMEM buffer. + # The simplest correct approach: use epilogue_tma_store but read from + # a SMEM buffer instead of TMEM. + # + # For the MVP, we use the existing sC layout and write via + # the epi_tile partition that TMA expects. + + # Use epilogue_tma_store to write sO_acc -> GMEM + # But epilogue_tma_store reads from TMEM, not SMEM. + # We need a different TMA store path. + # + # Simplest: use cpasync.bulk_copy (SMEM->GMEM) with sC as source. + # First: copy sO_acc -> sC (FP32->BF16 cast) + # Then: TMA bulk copy sC -> GMEM + # + # Write to sC row by row using the epi_tile coordinate mapping. + # The epi_tile shape is derived from cta_tile_shape_mnk. + # For hd=64 with pv_n_tile=64: epi_tile covers (128, 64). + + # For each row assigned to this thread, cast FP32->BF16 + # and write to sC using flat index mapping. + # sC is 2-stage: sC[128, pv_n_tile, num_c_stage] in BF16 + c_stage0 = cute.slice_(sC, (None, None, 0)) # First stage of sC + for col in cutlass.range(0, self.pv_n_tile, unroll=1): + row = sfw_idx + if row < Int32(128): + c_stage0[row, col] = sO_acc[row, col].to(self.o_dtype) + + # TMA store sC -> GMEM + cute.arch.fence_proxy("async.shared", space="cta") + c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) + c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp) + c_pipe.producer_acquire() + cute.copy(tma_c, c_stage0, tCgC[(None, None, Int32(0))]) + c_pipe.producer_commit() + c_pipe.producer_tail() + + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr) \ No newline at end of file diff --git a/tests/unit/test_smem_acc.py b/tests/unit/test_smem_acc.py new file mode 100644 index 00000000..547d9d0d --- /dev/null +++ b/tests/unit/test_smem_acc.py @@ -0,0 +1,112 @@ +""" +Test SMEM accumulator FMHA kernel: multi-KV-tile with in-kernel O accumulation. +No Python KV merge needed — the kernel handles acc_scale internally. +""" +import torch, math, sys +import cutlass.cute as cute +import cutlass.torch as ct +import cuda.bindings.driver as cuda +from dsv4.kernels.attention.fmha_smem_acc import FmhaKernel + + +def test_smem_acc(hd=64, s_k=256, use_smem_p=False, normalize=True): + m = 128 + n_kv_tiles = s_k // 128 + torch.manual_seed(42) + + q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda') + k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda') + v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda') + c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda') + + # FP32 reference + qf = q[:, :, 0].float() + kf = k[:, :, 0].float() + scale = 1.0 / math.sqrt(hd) + attn_max = (qf @ kf.T * scale).max(dim=-1, keepdim=True)[0] + attn_exp = torch.exp(qf @ kf.T * scale - attn_max) + attn_sum = attn_exp.sum(dim=-1, keepdim=True) + ref_norm = (attn_exp / attn_sum) @ v.float() + ref_unnorm = attn_exp @ v.float() + + lse_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda') + row_sums_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda') + + kernel = FmhaKernel(head_dim=hd, s_k=s_k, use_smem_p=use_smem_p, normalize=normalize) + pv_n_tile = kernel.pv_n_tile + n_pv_tiles = kernel.n_pv_tiles + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + # Compile + v_tile = v[:, 0:pv_n_tile].contiguous() + v_kernel = v_tile.unsqueeze(-1) + c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda') + + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) + mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) + mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile)) + mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor)) + mRS = ct.from_dlpack(row_sums_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(row_sums_tensor)) + + print(f' hd={hd}, s_k={s_k} ({n_kv_tiles} KV tiles, pv_n_tile={pv_n_tile}, n_pv_tiles={n_pv_tiles}): Compiling...', flush=True) + compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE, mRS) + + for nt in range(n_pv_tiles): + v_start = nt * pv_n_tile + v_end = v_start + pv_n_tile + v_tile = v[:, v_start:v_end].contiguous() + v_kernel = v_tile.unsqueeze(-1) + c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda') + lse_tensor.zero_() + row_sums_tensor.zero_() + + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) + mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) + mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile)) + mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor)) + mRS = ct.from_dlpack(row_sums_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(row_sums_tensor)) + + compiled(mQ, mK, mV, mC, stream, mLSE, mRS) + torch.cuda.synchronize() + + c[:, v_start:v_end, :] = c_tile + + out = c[:, :, 0].float() + + if normalize: + cos = torch.nn.functional.cosine_similarity( + out.flatten().unsqueeze(0), ref_norm.flatten().unsqueeze(0) + ).item() + ref = ref_norm + else: + cos = torch.nn.functional.cosine_similarity( + out.flatten().unsqueeze(0), ref_unnorm.flatten().unsqueeze(0) + ).item() + ref = ref_unnorm + + status = "PASS" if cos >= 0.99 else "FAIL" + print(f' hd={hd}, s_k={s_k} ({n_kv_tiles} tiles): cos {cos:.6f} {status}') + return cos + + +def test(): + print("=== SMEM Accumulator FMHA: In-Kernel Multi-KV-Tile O Accumulation ===\n") + + # Single KV tile (s_k=128): should work like fmha.py + print("--- Single KV tile (s_k=128) ---") + test_smem_acc(64, 128) + test_smem_acc(128, 128) + + # Multi KV tile: the SMEM accumulator approach should handle this correctly + print("\n--- Multi KV tile (s_k=256+) ---") + test_smem_acc(64, 256) + test_smem_acc(64, 384) + test_smem_acc(64, 512) + test_smem_acc(128, 256) + + +if __name__ == '__main__': + test()