From 0f943ab48cf5abe38661040f63a7f4b0eb5d7f28 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 03:31:04 +0000 Subject: [PATCH] =?UTF-8?q?feat:=20SMEM-P=20via=20gP=E2=86=92TMA=E2=86=92s?= =?UTF-8?q?P=20path=20(register=E2=86=92GMEM=E2=86=92SMEM)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dsv4/kernels/attention/fmha.py | 52 +++++++++++++++++++++++++--------- 1 file changed, 38 insertions(+), 14 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 2ed27d7e..67bd85e4 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -113,19 +113,35 @@ class FmhaKernel: tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape) epi_s = cute.select(self.c_smem_s,mode=[0,1]) tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile) + + # SMEM-P: gP buffer and TMA for P (GMEM→SMEM via TMA) + if self.use_smem_p: + p_s = cute.slice_(self.p_smem_s,(None,None,None,0)) + gP = torch.zeros(128, self.s_k, dtype=torch.bfloat16, device='cuda') + mgP = ct.from_dlpack(gP).mark_layout_dynamic(leading_dim=ct.get_leading_dim(gP)) + tma_p = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileG2SOp(), mgP, p_s, self.qk_mma_tiler) + else: + # Dummy gP and tma_p (not used, dead-code-eliminated) + gP = torch.zeros(128, self.s_k, dtype=torch.bfloat16, device='cuda') + mgP = ct.from_dlpack(gP).mark_layout_dynamic(leading_dim=ct.get_leading_dim(gP)) + # Create a dummy TMA using the V SMEM layout (same structure, unused) + v_s = cute.slice_(self.v_smem_s,(None,None,None,0)) + tma_p = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileG2SOp(), mgP, v_s, self.qk_mma_tiler) # Always create a valid mLSE tensor for the kernel. # CuTeDSL doesn't support None parameters in @cute.kernel. # For normalize=True, mLSE is unused (dead-code-eliminated by compiler). if const_expr(lse is None): lse = cute.make_tensor(c.iterator, cute.make_layout((1,), stride=(0,))) - self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream) + self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,tma_p,mgP,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream) @cute.kernel - def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE): + def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, tma_p, mgP, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) tidx,_,_ = cute.arch.thread_idx() if warp_idx == self.tma_warp_id: cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c) + if const_expr(self.use_smem_p): + cpasync.prefetch_descriptor(tma_p) @cute.struct class SS: @@ -226,6 +242,12 @@ class FmhaKernel: sh.commit() softmax_done_bar.arrive_and_wait() pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0) + if const_expr(self.use_smem_p): + # SMEM-P: TMA load gP → sP (MMA warp does this after barrier) + tPgP, tPsP = cpasync.tma_partition(tma_p, 0, cute.nvgpu.OperandMajorMode.M, cute.group_modes(sP,0,3), cute.group_modes(mgP,0,3)) + cute.copy(tma_p, tPsP[(None,0,None,0)], tPgP[(None,0,None,0)], tma_bar_ptr=st.s_bar.data_ptr()) + cpasync.commit_group() + cpasync.wait_group(0) if not self.use_smem_p: # TMEM-P: PV reads P from TMEM for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True): @@ -366,17 +388,19 @@ class FmhaKernel: cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() else: - # SMEM-P: write P to sP using coordinate-indexed store. - for j0 in range(32): - for j1 in range(4): - coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0] - m_coord = coord[0] - k_coord = coord[1] - k0 = k_coord % 16 - k1 = (k_coord // 16) % 4 - k2 = k_coord // 64 - _sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0] - cute.arch.fence_proxy("async.shared", space="cta") + # SMEM-P: write P to gP (global memory), then TMA loads gP→sP. + # rP_bf16 and tCgP are both partitioned by the QK MMA's C-fragment, + # so they have the same thread→value mapping. A simple element-wise + # copy from rP_bf16 to tCgP puts P values at the correct gP positions. + gP_local = cute.local_tile(mgP, (128, self.s_k), (0, 0)) + tCgP = qk_thr.partition_C(gP_local) + # Flatten both tensors for element-wise copy + rP_flat = cute.make_tensor(rP_bf16.iterator, cute.coalesce(rP_bf16.layout)) + gP_flat = cute.make_tensor(tCgP.iterator, cute.coalesce(tCgP.layout)) + # Copy element-by-element (both should have 128 values per thread) + for idx in cutlass.range(cute.size(rP_flat), vectorize=True): + gP_flat[idx] = rP_flat[idx] + cute.arch.fence_proxy("async.global", space="cta") if kt > 0: for i in range(n_corr_tiles): tTMrO_i_ = tTMrO[None, i] @@ -440,7 +464,7 @@ class FmhaKernel: if sfw_idx == 0: _ln2 = Float32(0.6931471805599453) # ln(2) lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2 - mLSE[0] = lse_val + mLSE[0] = lse_val.to(self.q_dtype) tmem.relinquish_alloc_permit() tmem.free(tmem_ptr)