From d36b727898eb3473bbfea4e7d9a916ef225bd206 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 06:01:02 +0000 Subject: [PATCH] D1.1: Add SMEM-P path behind use_smem_p flag (stub: zero sP) --- dsv4/kernels/attention/fmha.py | 74 +++++++++++++++++++++++++--------- 1 file changed, 54 insertions(+), 20 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index b6777452..4cb3682c 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -15,10 +15,11 @@ import math class FmhaKernel: - def __init__(self, head_dim=64, s_k=128, scale_softmax=None): + def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None): self.head_dim = head_dim self.s_k = s_k self.n_kv_tiles = s_k // 128 + self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64) self.acc_dtype = Float32; self.qk_acc_dtype = Float32 self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1 @@ -45,18 +46,30 @@ class FmhaKernel: self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage) self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + # P SMEM layout (PV A-operand) — used for SMEM-P path + self.p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) qk_thr = qk_mma.get_slice(0); qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) tStS = qk_thr.make_fragment_C(qk_as) pv_thr = pv_mma.get_slice(0); pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) tOtO = pv_thr.make_fragment_C(pv_as) - self.tmem_s0_offset = 0; self.tmem_p0_offset = 32 - p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width - p_end = self.tmem_p0_offset + p_cols_fp32 - s_cols = self.qk_mma_tiler[1] - o_after = max(s_cols, p_end) - self.tmem_o0_offset = ((o_after + 31) // 32) * 32 - o_cols = find_tmem_tensor_col_offset(tOtO) - total = self.tmem_o0_offset + o_cols + self.tmem_s0_offset = 0 + if not self.use_smem_p: + # TMEM-P: S at 0, P at 32, O after P and S + self.tmem_p0_offset = 32 + p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width + p_end = self.tmem_p0_offset + p_cols_fp32 + s_cols = self.qk_mma_tiler[1] + o_after = max(s_cols, p_end) + self.tmem_o0_offset = ((o_after + 31) // 32) * 32 + o_cols = find_tmem_tensor_col_offset(tOtO) + total = self.tmem_o0_offset + o_cols + else: + # SMEM-P: P not in TMEM. S and O share TMEM (sequential). + self.tmem_p0_offset = -1 # unused + self.tmem_o0_offset = 0 + s_cols = self.qk_mma_tiler[1] + o_cols = find_tmem_tensor_col_offset(tOtO) + total = max(s_cols, o_cols) self.num_tmem_alloc_cols = 1 while self.num_tmem_alloc_cols < total: self.num_tmem_alloc_cols *= 2 @@ -83,7 +96,9 @@ class FmhaKernel: self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode() self.c_layout = LayoutEnum.from_tensor(c) qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM) - pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, self.qk_acc_dtype, self.cta_group, (128,self.head_dim), tcgen05.OperandSource.TMEM) + pv_a_major = self.a_major if self.use_smem_p else cute.nvgpu.OperandMajorMode.K + pv_source = tcgen05.OperandSource.SMEM if self.use_smem_p else tcgen05.OperandSource.TMEM + pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, pv_a_major, self.v_major, self.qk_acc_dtype, self.cta_group, (128,self.head_dim), pv_source) self._setup(qk_mma, pv_mma) q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0)) tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape) @@ -91,10 +106,10 @@ class FmhaKernel: tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape) epi_s = cute.select(self.c_smem_s,mode=[0,1]) tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile) - self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.c_smem_s,self.epi_tile).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream) + self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream) @cute.kernel - def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) tidx,_,_ = cute.arch.thread_idx() if warp_idx == self.tma_warp_id: @@ -123,6 +138,7 @@ class FmhaKernel: sK = smem.allocate_tensor(element_type=self.q_dtype,layout=k_smem_s.outer,byte_alignment=128,swizzle=k_smem_s.inner) sV = smem.allocate_tensor(element_type=self.q_dtype,layout=v_smem_s.outer,byte_alignment=128,swizzle=v_smem_s.inner) sC = smem.allocate_tensor(element_type=self.o_dtype,layout=c_smem_s.outer,byte_alignment=128,swizzle=c_smem_s.inner) + sP = smem.allocate_tensor(element_type=self.q_dtype,layout=p_smem_s.outer,byte_alignment=128,swizzle=p_smem_s.inner) gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None)) gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None)) @@ -150,12 +166,14 @@ class FmhaKernel: tOtO = pv_thr.make_fragment_C(pv_as) tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + # PV A-operand: always define both TMEM and SMEM paths (CuTeDSL scoping) tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) tOrP_base = pv_thr.make_fragment_A(tP) tOrP = tOrP_base[(None,None,None,0)] tOrP0 = cute.make_tensor( - tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * max(self.tmem_p0_offset, 0), tOrP.layout) + tCrP = pv_mma.make_fragment_A(sP) tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage)) pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) @@ -191,9 +209,16 @@ class FmhaKernel: sh.commit() softmax_done_bar.arrive_and_wait() pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0) - for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True): - cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0) - pv_mma.set(tcgen05.Field.ACCUMULATE, True) + if not self.use_smem_p: + # TMEM-P: PV reads P from TMEM + for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True): + cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0) + pv_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + # SMEM-P: PV reads P from SMEM + for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True): + cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh.index)], tOtO0) + pv_mma.set(tcgen05.Field.ACCUMULATE, True) cute.arch.fence_view_async_tmem_store() kvh.release() acc_pipe.producer_commit(acc_st); acc_st.advance() @@ -216,10 +241,11 @@ class FmhaKernel: tScS = qk_thr.partition_C(cS) tTMEM_LOADcS = thr_load.partition_D(tScS) - # P store atoms + # P store atoms (always defined for CuTeDSL scoping; only used when use_smem_p=False) p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32))) - tStP0 = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStP_layout) + # Use 0 as P offset when SMEM-P (these atoms are never used, but must be valid) + tStP0 = cute.make_tensor(tStS.iterator + max(self.tmem_p0_offset, 0), tStP_layout) tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0) thr_store = tiled_tmem_store.get_slice(sfw_idx) @@ -294,8 +320,16 @@ class FmhaKernel: s_vec = tTMEM_LOADrS_frg[None, j].load() rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype)) - cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) - cute.arch.fence_view_async_tmem_store() + if not self.use_smem_p: + # TMEM-P: store P to TMEM via register bridge + cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) + cute.arch.fence_view_async_tmem_store() + else: + # SMEM-P: TODO — write P to SMEM via make_tiled_copy_C(store_atom, qk_mma) + # For now, zero sP as stub. PV will produce garbage with SMEM-P path. + for j in cutlass.range(cute.size(sP), vectorize=True): + sP[j] = BFloat16(0.0) + cute.arch.fence_proxy("async.shared", space="cta") # Per-tile O rescale (hand-constructed atoms with logical_divide layout) if kt > 0: