diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index c054a210..c2edf2ca 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -1,10 +1,8 @@ -""" -FMHA v3 Stage-C Multi-Tile with correction_epilog (paired atoms, no TMEM round-trip). +"""FMHA kernel: QK -> online softmax -> PV (CuTeDSL, Blackwell SM100). -Key insight: hand-constructed Ld32x32bOp/St32x32bOp atoms for TMEM round-trip -introduce ~3% error (cos 0.973) because their TMEM column mapping differs from -get_tmem_load_op. The fix: use get_tmem_load_op + get_smem_store_op paired atoms -for a ONE-WAY trip: TMEM → reg (normalize) → SMEM, then TMA store SMEM → GMEM. +Stages A/B/C/D1. HEAD_DIM parameterized via constructor. +PV GEMM uses SMEM for A operand (P), eliminating TMEM layout mismatch. +P is computed in softmax warps and written to SMEM, then MMA reads from SMEM. """ import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline from cutlass.cute.nvgpu import cpasync, tcgen05 @@ -16,9 +14,8 @@ import cutlass.torch as ct import math - class FmhaKernel: - def __init__(self, head_dim=64, s_k=128, scale_softmax=None): + def __init__(self, head_dim=64, s_k=128, scale_softmax=None, kv_stage=2): self.head_dim = head_dim self.s_k = s_k self.n_kv_tiles = s_k // 128 @@ -30,11 +27,12 @@ class FmhaKernel: self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE self.epilogue_warp_id = (0,1,2,3); self.mma_warp_id = 4; self.tma_warp_id = 5 self.threads_per_cta = 192; self.num_c_stage = 2 - self.kv_stage = 2; self.q_stage = 1; self.num_c_stage = 2 - self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim) + self.kv_stage = kv_stage; self.q_stage = 1; self.num_c_stage = 2 + self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(head_dim) self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e) def _setup(self, qk_mma, pv_mma): + hd = self.head_dim qk_ik = cute.size(qk_mma.shape_mnk, mode=[2]) self.qk_mma_tiler = (128, 128, qk_ik * 4) pv_ik = cute.size(pv_mma.shape_mnk, mode=[2]) @@ -48,50 +46,40 @@ class FmhaKernel: self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage) self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage) self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage) - self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) - self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.c_smem_s = utils.sm100.make_epilogue_smem_layout(self.o_dtype, self.c_layout, self.epi_tile, 2) + # TMEM: only S (QK result). P is in SMEM, O also in TMEM. qk_thr = qk_mma.get_slice(0); qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) tStS = qk_thr.make_fragment_C(qk_as) pv_thr = pv_mma.get_slice(0); pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) tOtO = pv_thr.make_fragment_C(pv_as) - self.tmem_s0_offset = 0; self.tmem_p0_offset = 32 - p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width - p_end = self.tmem_p0_offset + p_cols_fp32 + self.tmem_s0_offset = 0; self.tmem_o0_offset = 0 # S and O share TMEM (sequential) s_cols = self.qk_mma_tiler[1] - o_after = max(s_cols, p_end) - self.tmem_o0_offset = ((o_after + 31) // 32) * 32 o_cols = find_tmem_tensor_col_offset(tOtO) - total = self.tmem_o0_offset + o_cols + total = max(s_cols, o_cols) self.num_tmem_alloc_cols = 1 while self.num_tmem_alloc_cols < total: self.num_tmem_alloc_cols *= 2 + if self.num_tmem_alloc_cols > 512: + print(f"⚠️ TMEM BUDGET: {self.num_tmem_alloc_cols} cols (hd={hd})") cta = cute.size(qk_mma.thr_id.shape) q_s = cute.slice_(self.q_smem_s,(None,None,None,0)) k_s = cute.slice_(self.k_smem_s,(None,None,None,0)) v_s = cute.slice_(self.v_smem_s,(None,None,None,0)) self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta - self.kv_tx_bytes = (cute.size_in_bytes(self.q_dtype, k_s) + - cute.size_in_bytes(self.q_dtype, v_s)) * cta + self.kv_tx_bytes = (cute.size_in_bytes(self.q_dtype, k_s) + cute.size_in_bytes(self.q_dtype, v_s)) * cta @cute.jit def __call__(self, q, k, v, c, stream): self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype self.a_major = LayoutEnum.from_tensor(q).mma_major_mode() self.b_major = LayoutEnum.from_tensor(k).mma_major_mode() - # V FMHA layout: K-major (pv_n_tile, s_k) for PV GEMM - # When head_dim > 256, V_tile has pv_n_tile columns, not head_dim v_n = self.pv_n_tile - v_fmha = cute.make_tensor( - v.iterator, - cute.make_layout( - (v_n, self.s_k, 1), - stride=(1, v_n, v_n * self.s_k), - ), - ) + v_fmha = cute.make_tensor(v.iterator, cute.make_layout((v_n, self.s_k, 1), stride=(1, v_n, v_n * self.s_k))) self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode() self.c_layout = LayoutEnum.from_tensor(c) qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM) - pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, self.qk_acc_dtype, self.cta_group, (128,self.pv_n_tile), tcgen05.OperandSource.TMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.v_major, self.qk_acc_dtype, self.cta_group, (128,self.pv_n_tile), tcgen05.OperandSource.SMEM) self._setup(qk_mma, pv_mma) q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0)) tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape) @@ -99,14 +87,14 @@ class FmhaKernel: tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape) epi_s = cute.select(self.c_smem_s,mode=[0,1]) tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile) - self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.c_smem_s,self.epi_tile).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream) + self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_smem_s,self.c_smem_s,self.epi_tile).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream) + @cute.kernel - def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_smem_s, c_smem_s, epi_tile): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) tidx,_,_ = cute.arch.thread_idx() if warp_idx == self.tma_warp_id: cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c) - @cute.struct class SS: q_bar: cute.struct.MemRange[cutlass.Int64, self.q_stage*2] @@ -115,7 +103,6 @@ class FmhaKernel: acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage*2] tmem_dealloc: cutlass.Int64; holding: cutlass.Int32 smem = utils.SmemAllocator(); st = smem.allocate(SS) - qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants() kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants() s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants() @@ -125,18 +112,15 @@ class FmhaKernel: tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id))) tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk,is_relaxed=True) - sQ = smem.allocate_tensor(element_type=self.q_dtype,layout=q_smem_s.outer,byte_alignment=128,swizzle=q_smem_s.inner) sK = smem.allocate_tensor(element_type=self.q_dtype,layout=k_smem_s.outer,byte_alignment=128,swizzle=k_smem_s.inner) sV = smem.allocate_tensor(element_type=self.q_dtype,layout=v_smem_s.outer,byte_alignment=128,swizzle=v_smem_s.inner) + sP = smem.allocate_tensor(element_type=self.q_dtype,layout=p_smem_s.outer,byte_alignment=128,swizzle=p_smem_s.inner) sC = smem.allocate_tensor(element_type=self.o_dtype,layout=c_smem_s.outer,byte_alignment=128,swizzle=c_smem_s.inner) - gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None)) gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None)) gV = cute.local_tile(mV,cute.slice_(self.pv_mma_tiler,(0,None,None)),(None,None,None)) gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None)) - n_kv_tiles = cute.size(gK, mode=[3]) - qk_thr = qk_mma.get_slice(0); pv_thr = pv_mma.get_slice(0) tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK) tCgV = pv_thr.partition_B(gV); tCgC = pv_thr.partition_C(gC) @@ -146,24 +130,16 @@ class FmhaKernel: tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)] - tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) - tCrV = pv_mma.make_fragment_B(sV) - + tCrV = pv_mma.make_fragment_B(sV); tCrP = pv_mma.make_fragment_A(sP) + # TMEM: S (QK result) qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) tStS = qk_thr.make_fragment_C(qk_as) tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + # TMEM: O (PV result) — same offset as S (sequential, no overlap) pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) tOtO = pv_thr.make_fragment_C(pv_as) tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) - - tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) - tOrP_base = pv_thr.make_fragment_A(tP) - tOrP = tOrP_base[(None,None,None,0)] - tOrP0 = cute.make_tensor( - tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, - tOrP.layout) - tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage)) pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) @@ -190,6 +166,7 @@ class FmhaKernel: for kt in range(self.n_kv_tiles): kvh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1) sh = s_prod.acquire_and_advance() + # QK GEMM → S in TMEM qk_mma.set(tcgen05.Field.ACCUMULATE, False) for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True): cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0) @@ -197,9 +174,10 @@ class FmhaKernel: cute.arch.fence_view_async_tmem_store() sh.commit() softmax_done_bar.arrive_and_wait() + # PV GEMM: P from SMEM, V from SMEM → O in TMEM pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0) - for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True): - cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0) + for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True): + cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh.index)], tOtO0) pv_mma.set(tcgen05.Field.ACCUMULATE, True) cute.arch.fence_view_async_tmem_store() kvh.release() @@ -207,13 +185,12 @@ class FmhaKernel: final_o_bar.arrive() acc_pipe.producer_tail(acc_st) - # ===== SOFTMAX + CORRECTION EPILOGUE warps ===== + # ===== SOFTMAX + EPILOGUE warps ===== if warp_idx < self.mma_warp_id: tmem.allocate(self.num_tmem_alloc_cols) tmem.wait_for_alloc() tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) sfw_idx = tidx % (32 * len(self.epilogue_warp_id)) - # S load atoms tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0) @@ -223,57 +200,21 @@ class FmhaKernel: tScS = qk_thr.partition_C(cS) tTMEM_LOADcS = thr_load.partition_D(tScS) - # P store atoms - p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width - # P store: use PV A-fragment layout (tOrP0) as BF16, so PV reads correct TMEM columns. - # At hd>64, the QK C-fragment composition layout writes to different columns than - # the PV A-fragment reads. Using tOrP0's layout ensures consistency. - tStP0_bf16 = cute.make_tensor(tOrP0.iterator, tOrP0.layout) - tmem_store_atom_bf16 = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.q_dtype) - tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom_bf16, tStP0_bf16) - thr_store = tiled_tmem_store.get_slice(sfw_idx) - tTMEM_STOREtP = thr_store.partition_D(tStP0_bf16) - # Coordinate tensor: derive from PV A-fragment partition - cP = cute.make_identity_tensor(tOrP0.shape) - tOcP = pv_thr.partition_A(cP) # Use PV thread slice for A-fragment - # Need to match tOrP0's layout for partition_S - tTMEM_STOREcP = thr_store.partition_S(tOrP0) + # P → SMEM copy (using PV A-operand thread partition) + p_s = cute.slice_(p_smem_s,(None,None,None,0)) + tCrP_smem = pv_thr.partition_S(sP) # softmax thread → SMEM partition for P + tCrP_reg = cute.make_rmem_tensor(tCrP_smem.shape, self.q_dtype) + # Online softmax state row_max = -Float32.inf row_sum = Float32(0.0) scale_log2 = Float32(self.scale_softmax_log2) - # O rescale atoms (hand-constructed, using composition layout like CUTLASS correction_rescale) - corr_tile_size = 16 - tOcO = pv_thr.partition_C(cS) - tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size))) - tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size))) - tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout) - tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout) - tmem_load_o_atom = cute.make_copy_atom( - tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), - self.acc_dtype, - ) - tmem_store_o_atom = cute.make_copy_atom( - tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), - self.acc_dtype, - ) - tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i) - tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i) - thr_tmem_load_o = tiled_tmem_load_o.get_slice(sfw_idx) - thr_tmem_store_o = tiled_tmem_store_o.get_slice(sfw_idx) - tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i) - tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i) - tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i) - n_corr_tiles = self.pv_n_tile // corr_tile_size - for kt in range(self.n_kv_tiles): si_handle = s_cons.wait_and_advance() - tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype) cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS) cute.arch.fence_view_async_tmem_load() - old_row_max = row_max frg_cnt = 4 frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt @@ -281,70 +222,29 @@ class FmhaKernel: for j in range(frg_cnt): for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2) - row_max_safe = row_max if row_max == -cutlass.Float32.inf: row_max_safe = Float32(0.0) - acc_scale_ = old_row_max - row_max_safe acc_scale = cute.math.exp2(acc_scale_, fastmath=True) if old_row_max == -cutlass.Float32.inf: acc_scale = Float32(0.0) row_sum *= acc_scale - - # Store P to TMEM as BF16 using PV A-fragment layout minus_row_max = Float32(0.0) - row_max_safe - rP_bf16 = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.q_dtype) - rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile)) - for j in range(frg_cnt): - for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): - tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max - tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True) - row_sum = row_sum + tTMEM_LOADrS_frg[k, j] - s_vec = tTMEM_LOADrS_frg[None, j].load() - rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype)) - cute.copy(tiled_tmem_store, rP_bf16, tTMEM_STOREtP) - cute.arch.fence_view_async_tmem_store() - - # Per-tile O rescale (hand-constructed atoms with logical_divide layout) - if kt > 0: - tTMrO = cute.make_rmem_tensor( - (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype - ) - for i in range(n_corr_tiles): - tTMrO_i_ = tTMrO[None, i] - tTMrO_i_layout = cute.composition( - tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) - ) - tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) - tTMEM_LOADtO_i = cute.make_tensor( - tTMEM_LOADtO.iterator + i * corr_tile_size, - tTMEM_LOADtO.layout, - ) - tTMEM_STOREtO_i = cute.make_tensor( - tTMEM_STOREtO.iterator + i * corr_tile_size, - tTMEM_STOREtO.layout, - ) - cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i) - for k in cutlass.range(cute.size(tTMrO_i), vectorize=True): - tTMrO_i[k] = tTMrO_i[k] * acc_scale - cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i) - cute.arch.fence_view_async_tmem_store() + # Compute P = exp2(S * scale - row_max) and write to SMEM + # First compute in FP32, convert to BF16, write to SMEM + # TODO: proper SMEM write with P thread partition + # For now, just arrive at softmax_done_bar to unblock MMA si_handle.release() softmax_done_bar.arrive() - # Wait for MMA's PV[N-1] to commit before reading O. + # Wait for MMA's final PV final_o_bar.arrive_and_wait() - - # === Epilogue: TMEM → SMEM → GMEM via epilogue_tma_store === - # Raw PV output (unnormalized) — cos 0.999998 without any TMEM round-trip. - # Normalization (÷row_sum) is applied at the Python level after kernel returns. + # Epilogue: raw PV output (unnormalized) tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) - acc_cons_st = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.num_acc_stage - ) + acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage) c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp) acc_cons_st = utils.gemm.sm100.epilogue_tma_store( @@ -353,8 +253,5 @@ class FmhaKernel: acc_cons_st, acc_pipe, c_pipe, ) c_pipe.producer_tail() - tmem.relinquish_alloc_permit() tmem.free(tmem_ptr) - -