From 4fe9bbab48d5478739fc6bf89207c4ba7f44fbc5 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 07:04:59 +0000 Subject: [PATCH] add back in the archived code --- .../archive/fmha_backup_pre_epilog.py | 492 +++++++++++++++ .../attention/archive/fmha_backup_v2.py | 592 ++++++++++++++++++ .../attention/archive/fmha_smem_acc.py | 592 ++++++++++++++++++ 3 files changed, 1676 insertions(+) diff --git a/dsv4/kernels/attention/archive/fmha_backup_pre_epilog.py b/dsv4/kernels/attention/archive/fmha_backup_pre_epilog.py index 29657ec4..7bca0813 100644 --- a/dsv4/kernels/attention/archive/fmha_backup_pre_epilog.py +++ b/dsv4/kernels/attention/archive/fmha_backup_pre_epilog.py @@ -13,4 +13,496 @@ WHY ARCHIVED: Superseded by the current fmha.py which has: This backup uses the old TMEM round-trip approach which is FUNDAMENTALLY BROKEN (Ld32x32bOp/St32x32bOp column mismatch, even NO-OP round-trip produces ~3% error). + +FMHA kernel: QK -> online softmax -> PV (CuTeDSL, Blackwell SM100). + +Migrated from tests/unit/test_fmha_v3_stage_c.py — Stage C proven path. +P stored to TMEM via register bridge, PV reads from TMEM. +O rescale via correction_rescale atoms, O normalization via TMEM round-trip. """ +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +from cutlass.utils.blackwell_helpers import get_smem_store_op +import cuda.bindings.driver as cuda +import cutlass.torch as ct +import math + + +class FmhaKernel: + def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True): + self.head_dim = head_dim + self.s_k = s_k + self.n_kv_tiles = s_k // 128 + self.pv_n_tile = min(head_dim, 256) # tcgen05 MMA max N=256 + self.n_pv_tiles = head_dim // self.pv_n_tile + self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64) + self.normalize = normalize # D5a: False = emit un-normalized O + lse + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 + self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1 + self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0,1,2,3); self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192; self.num_c_stage = 2 + self.kv_stage = 2; self.q_stage = 1; self.num_c_stage = 2 + self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim) + self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e) + + def _setup(self, qk_mma, pv_mma): + qk_ik = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (128, 128, qk_ik * 4) + pv_ik = cute.size(pv_mma.shape_mnk, mode=[2]) + self.pv_mma_tiler = (128, self.pv_n_tile, pv_ik * (128 // pv_ik)) + self.mma_tiler = self.qk_mma_tiler + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), self.pv_n_tile, self.qk_mma_tiler[2]) + self.c_layout = LayoutEnum.ROW_MAJOR + self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage) + self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + # P SMEM layout (PV A-operand) — used for SMEM-P path + self.p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + qk_thr = qk_mma.get_slice(0); qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_as) + pv_thr = pv_mma.get_slice(0); pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_as) + self.tmem_s0_offset = 0 + if not self.use_smem_p: + # TMEM-P: S at 0, P at 32, O after P and S + self.tmem_p0_offset = 32 + p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width + p_end = self.tmem_p0_offset + p_cols_fp32 + s_cols = self.qk_mma_tiler[1] + o_after = max(s_cols, p_end) + self.tmem_o0_offset = ((o_after + 31) // 32) * 32 + o_cols = find_tmem_tensor_col_offset(tOtO) + total = self.tmem_o0_offset + o_cols + else: + # SMEM-P: P not in TMEM. S and O share TMEM (sequential). + self.tmem_p0_offset = -1 # unused + self.tmem_o0_offset = 0 + s_cols = self.qk_mma_tiler[1] + o_cols = find_tmem_tensor_col_offset(tOtO) + total = max(s_cols, o_cols) + self.num_tmem_alloc_cols = 1 + while self.num_tmem_alloc_cols < total: + self.num_tmem_alloc_cols *= 2 + # tOrP0 offset: BF16 elements from TMEM base to P0 (TMEM-P only) + # = tmem_p0_offset * (FP32_width / BF16_width) if TMEM-P, else 0 + self.tOrP0_offset = max(self.tmem_p0_offset, 0) * 2 # Python int + cta = cute.size(qk_mma.thr_id.shape) + q_s = cute.slice_(self.q_smem_s,(None,None,None,0)) + k_s = cute.slice_(self.k_smem_s,(None,None,None,0)) + v_s = cute.slice_(self.v_smem_s,(None,None,None,0)) + self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta + self.kv_tx_bytes = (cute.size_in_bytes(self.q_dtype, k_s) + + cute.size_in_bytes(self.q_dtype, v_s)) * cta + + @cute.jit + def __call__(self, q, k, v, c, stream, lse=None): + self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype + self.a_major = LayoutEnum.from_tensor(q).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(k).mma_major_mode() + v_fmha = cute.make_tensor( + v.iterator, + cute.make_layout( + (self.pv_n_tile, self.s_k, 1), + stride=(1, self.pv_n_tile, self.pv_n_tile * self.s_k), + ), + ) + self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode() + self.c_layout = LayoutEnum.from_tensor(c) + qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM) + pv_a_major = self.a_major if self.use_smem_p else cute.nvgpu.OperandMajorMode.K + pv_source = tcgen05.OperandSource.SMEM if self.use_smem_p else tcgen05.OperandSource.TMEM + pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, pv_a_major, self.v_major, self.qk_acc_dtype, self.cta_group, (128,self.pv_n_tile), pv_source) + self._setup(qk_mma, pv_mma) + q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0)) + tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape) + tma_k,mK = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,qk_mma.thr_id),k,k_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape) + tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape) + epi_s = cute.select(self.c_smem_s,mode=[0,1]) + tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile) + # Always create a valid mLSE tensor for the kernel. + # CuTeDSL doesn't support None parameters in @cute.kernel. + # For normalize=True, mLSE is unused (dead-code-eliminated by compiler). + if const_expr(lse is None): + lse = cute.make_tensor(c.iterator, cute.make_layout((1,), stride=(0,))) + self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx,_,_ = cute.arch.thread_idx() + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + q_bar: cute.struct.MemRange[cutlass.Int64, self.q_stage*2] + kv_bar: cute.struct.MemRange[cutlass.Int64, self.kv_stage*2] + s_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage*2] + tmem_dealloc: cutlass.Int64; holding: cutlass.Int32 + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants() + kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants() + s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants() + softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id)) + final_o_bar = pipeline.NamedBarrier(barrier_id=4, num_threads=32 + 32*len(self.epilogue_warp_id)) + acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True) + tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk,is_relaxed=True) + + sQ = smem.allocate_tensor(element_type=self.q_dtype,layout=q_smem_s.outer,byte_alignment=128,swizzle=q_smem_s.inner) + sK = smem.allocate_tensor(element_type=self.q_dtype,layout=k_smem_s.outer,byte_alignment=128,swizzle=k_smem_s.inner) + sV = smem.allocate_tensor(element_type=self.q_dtype,layout=v_smem_s.outer,byte_alignment=128,swizzle=v_smem_s.inner) + sC = smem.allocate_tensor(element_type=self.o_dtype,layout=c_smem_s.outer,byte_alignment=128,swizzle=c_smem_s.inner) + sP = smem.allocate_tensor(element_type=self.q_dtype,layout=p_smem_s.outer,byte_alignment=128,swizzle=p_smem_s.inner) + + gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None)) + gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None)) + gV = cute.local_tile(mV,cute.slice_(self.pv_mma_tiler,(0,None,None)),(None,None,None)) + gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None)) + n_kv_tiles = cute.size(gK, mode=[3]) + + qk_thr = qk_mma.get_slice(0); pv_thr = pv_mma.get_slice(0) + tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK) + tCgV = pv_thr.partition_B(gV); tCgC = pv_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,0,None,0)).shape) + tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(sQ,0,3),cute.group_modes(tCgQ,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) + tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) + tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) + tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)] + + tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) + tCrV = pv_mma.make_fragment_B(sV) + + qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_as) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_as) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + # PV A-operand: define both tOrP0 (TMEM-P) and tCrP (SMEM-P) unconditionally. + # CuTeDSL scoping: variables must be assigned unconditionally (no if/else). + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + tOrP_base = pv_thr.make_fragment_A(tP if not self.use_smem_p else sP) + tOrP = tOrP_base[(None,None,None,0)] + tCrP = pv_mma.make_fragment_A(sP) + # tOrP0: PV A-operand with TMEM column offset for P0 (TMEM-P path). + # self.tOrP0_offset is pre-computed in _setup as a Python int. + # Use const_expr if/else for compile-time conditional. + if const_expr(self.tOrP0_offset > 0): + tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout) + else: + tOrP0 = tOrP + + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage)) + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # ===== TMA LOAD warp ===== + if warp_idx == self.tma_warp_id: + qp.reset(); qh = qp.acquire_and_advance() + cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) + qp.tail() + kvp.reset(); pk = kvp.try_acquire() + for kt in cutlass.range(0, self.n_kv_tiles, 1, unroll=1): + kvh = kvp.acquire_and_advance(pk) + cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + pk = cutlass.Boolean(1) + kvp.tail() + + # ===== MMA warp ===== + if warp_idx == self.mma_warp_id: + tmem.wait_for_alloc() + qc.reset(); qh = qc.wait_and_advance(); qh.release() + kvc.reset(); pk = kvc.try_wait() + acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage) + acc_pipe.producer_acquire(acc_st) + for kt in range(self.n_kv_tiles): + kvh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1) + sh = s_prod.acquire_and_advance() + qk_mma.set(tcgen05.Field.ACCUMULATE, False) + for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True): + cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0) + qk_mma.set(tcgen05.Field.ACCUMULATE, True) + cute.arch.fence_view_async_tmem_store() + sh.commit() + softmax_done_bar.arrive_and_wait() + pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0) + if not self.use_smem_p: + # TMEM-P: PV reads P from TMEM + for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True): + cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0) + pv_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + # SMEM-P: PV reads P from SMEM + for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True): + cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh.index)], tOtO0) + pv_mma.set(tcgen05.Field.ACCUMULATE, True) + cute.arch.fence_view_async_tmem_store() + kvh.release() + acc_pipe.producer_commit(acc_st); acc_st.advance() + final_o_bar.arrive() + acc_pipe.producer_tail(acc_st) + + # ===== SOFTMAX + CORRECTION EPILOGUE warps ===== + if warp_idx < self.mma_warp_id: + tmem.allocate(self.num_tmem_alloc_cols) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + sfw_idx = tidx % (32 * len(self.epilogue_warp_id)) + + # S load atoms + tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0) + thr_load = tiled_tmem_load.get_slice(sfw_idx) + tTMEM_LOADtS = thr_load.partition_S(tStS0) + cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS) + tTMEM_LOADcS = thr_load.partition_D(tScS) + + # P store atoms: TMEM-P (always defined, only used when use_smem_p=False) + p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width + tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32))) + # Use 0 as P offset when SMEM-P (these atoms are never used, but must be valid) + tStP0 = cute.make_tensor(tStS.iterator + max(self.tmem_p0_offset, 0), tStP_layout) + tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0) + thr_store = tiled_tmem_store.get_slice(sfw_idx) + tTMEM_STOREtP = thr_store.partition_D(tStP0) + tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32))) + tScP = cute.make_tensor(tScS.iterator, tScP_layout) + tTMEM_STOREcP = thr_store.partition_S(tScP) + + # P SMEM copy atoms: SMEM-P + # Per CUTLASS LLM guidance: use make_cotiled_copy with TV layout + # from TMEM load partition, remapped to sP's codomain. + # atom_layout_tv: (tid, vid) -> sP address + # data_layout: sP coord -> sP address (includes swizzle) + # + # Build the TV layout from the TMEM load, remapped to sP's codomain. + # The TMEM load's TV layout maps (tid, vid) -> tStS_addr. + # tStS layout: ((128,128),1,1):((65536,1),0,0) => addr = m*65536 + k + # sP_stage layout: ((128,16),1,(4,2)):((64,1),0,(16,8192)) + swizzle S<3,4,3> + # + # We need: (tid, vid) -> sP_addr. + # Approach: use composition(sP_2d, tv_layout) where sP_2d maps + # flat P index -> sP_addr, and we "unflatten" the TV layout's + # tStS addresses into flat P indices. + # + # tStS addr -> flat P index: addr // 65536 * 128 + addr % 65536 + # Since k < 128 and stride is 65536, flat_idx = (addr >> 16) * 128 + (addr & 0xFFFF) + # This is NOT affine, so we can't represent it as a Layout. + # + # FALLBACK: Use the coordinate-indexed approach (scalar SMEM writes). + # This works but gives ~0.04 cosine loss vs TMEM-P at hd=64. + # The make_cotiled_copy approach is tracked for future optimization. + _sP_nostage = sP[(None, None, None, 0)] # remove stage dim + + row_max = -Float32.inf + row_sum = Float32(0.0) + scale_log2 = Float32(self.scale_softmax_log2) + + # O rescale atoms (hand-constructed, using composition layout like CUTLASS correction_rescale) + corr_tile_size = 16 + tOcO = pv_thr.partition_C(cS) + tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size))) + tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size))) + tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout) + tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout) + tmem_load_o_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.acc_dtype, + ) + tmem_store_o_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.acc_dtype, + ) + tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i) + tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i) + thr_tmem_load_o = tiled_tmem_load_o.get_slice(sfw_idx) + thr_tmem_store_o = tiled_tmem_store_o.get_slice(sfw_idx) + tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i) + tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i) + tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i) + n_corr_tiles = self.pv_n_tile // corr_tile_size + + for kt in range(self.n_kv_tiles): + si_handle = s_cons.wait_and_advance() + + tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype) + cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS) + cute.arch.fence_view_async_tmem_load() + + old_row_max = row_max + frg_cnt = 4 + frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt + tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) + for j in range(frg_cnt): + for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): + row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2) + + row_max_safe = row_max + if row_max == -cutlass.Float32.inf: + row_max_safe = Float32(0.0) + + acc_scale_ = old_row_max - row_max_safe + acc_scale = cute.math.exp2(acc_scale_, fastmath=True) + if old_row_max == -cutlass.Float32.inf: + acc_scale = Float32(0.0) + row_sum *= acc_scale + + rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype) + rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout) + minus_row_max = Float32(0.0) - row_max_safe + + rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile)) + for j in range(frg_cnt): + for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): + tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max + tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True) + row_sum = row_sum + tTMEM_LOADrS_frg[k, j] + s_vec = tTMEM_LOADrS_frg[None, j].load() + rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype)) + + if not self.use_smem_p: + # TMEM-P: store P to TMEM via register bridge + cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) + cute.arch.fence_view_async_tmem_store() + else: + # SMEM-P: write P to sP using coordinate-indexed store. + # Uses tTMEM_LOADcS identity tensor to get (m, k) coordinates. + for j0 in range(32): + for j1 in range(4): + coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0] + m_coord = coord[0] + k_coord = coord[1] + k0 = k_coord % 16 + k1 = (k_coord // 16) % 4 + k2 = k_coord // 64 + _sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0] + cute.arch.fence_proxy("async.shared", space="cta") + if kt > 0: + tTMrO = cute.make_rmem_tensor( + (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype + ) + for i in range(n_corr_tiles): + tTMrO_i_ = tTMrO[None, i] + tTMrO_i_layout = cute.composition( + tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) + ) + tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) + tTMEM_LOADtO_i = cute.make_tensor( + tTMEM_LOADtO.iterator + i * corr_tile_size, + tTMEM_LOADtO.layout, + ) + tTMEM_STOREtO_i = cute.make_tensor( + tTMEM_STOREtO.iterator + i * corr_tile_size, + tTMEM_STOREtO.layout, + ) + cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i) + for k in cutlass.range(cute.size(tTMrO_i), vectorize=True): + tTMrO_i[k] = tTMrO_i[k] * acc_scale + cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i) + cute.arch.fence_view_async_tmem_store() + + si_handle.release() + softmax_done_bar.arrive() + + # Wait for MMA's PV[N-1] to commit before reading O. + final_o_bar.arrive_and_wait() + + # === NO-OP TMEM round-trip: re-map O from MMA layout to epilog layout === + tTMrO_noop = cute.make_rmem_tensor( + (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype + ) + for i in range(n_corr_tiles): + tTMrO_i_ = tTMrO_noop[None, i] + tTMrO_i_layout = cute.composition( + tTMrO_i_.layout, cute.make_layout(tTMrO_noop.shape[0]) + ) + tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) + tTMEM_LOADtO_i = cute.make_tensor( + tTMEM_LOADtO.iterator + i * corr_tile_size, + tTMEM_LOADtO.layout, + ) + tTMEM_STOREtO_i = cute.make_tensor( + tTMEM_STOREtO.iterator + i * corr_tile_size, + tTMEM_STOREtO.layout, + ) + cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i) + cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i) + cute.arch.fence_view_async_tmem_store() + + # === Final O normalization: O *= 1/row_sum === + # D5a: When normalize=False, skip normalization (emit un-normalized O + lse) + if const_expr(self.normalize): + inv_row_sum = Float32(1.0) / row_sum + + tTMrO = cute.make_rmem_tensor( + (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype + ) + + for i in range(n_corr_tiles): + tTMrO_i_ = tTMrO[None, i] + tTMrO_i_layout = cute.composition( + tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) + ) + tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) + tTMEM_LOADtO_i = cute.make_tensor( + tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout + ) + tTMEM_STOREtO_i = cute.make_tensor( + tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout + ) + + cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i) + if const_expr(self.normalize): + for j in cutlass.range(cute.size(tTMrO_i), vectorize=True): + tTMrO_i[j] = tTMrO_i[j] * inv_row_sum + cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i) + + cute.arch.fence_view_async_tmem_store() + + # Epilogue: TMEM → SMEM → GMEM via TMA store. + tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) + acc_cons_st = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) + c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp) + acc_cons_st = utils.gemm.sm100.epilogue_tma_store( + self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, + 0, const_expr(lambda x: x), (0, 0, 0), + acc_cons_st, acc_pipe, c_pipe, + ) + c_pipe.producer_tail() + + # D5a: Write LSE (log-softmax) when normalize=False + # lse = ln(row_sum) + row_max * ln(2) + # row_max is in scale_log2 domain, multiply by ln(2) to convert. + if const_expr(not self.normalize): + _row_max_safe = row_max + if row_max == -cutlass.Float32.inf: + _row_max_safe = Float32(0.0) + if sfw_idx == 0: + _ln2 = Float32(0.6931471805599453) # ln(2) + lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2 + mLSE[0] = lse_val + + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr) + diff --git a/dsv4/kernels/attention/archive/fmha_backup_v2.py b/dsv4/kernels/attention/archive/fmha_backup_v2.py index fab4e7e0..53b3a39f 100644 --- a/dsv4/kernels/attention/archive/fmha_backup_v2.py +++ b/dsv4/kernels/attention/archive/fmha_backup_v2.py @@ -3,4 +3,596 @@ ARCHIVED: FMHA kernel backup v2. Intermediate state during the SMEM accumulator development. Superseded by the current fmha.py. Kept for git-archaeology only. + +FMHA kernel: QK -> online softmax -> PV (CuTeDSL, Blackwell SM100). + +Migrated from tests/unit/test_fmha_v3_stage_c.py — Stage C proven path. +P stored to TMEM via register bridge, PV reads from TMEM. +O rescale via SMEM accumulator (one-way TMEM→REGS→SMEM per kt iteration). +Normalization via final TMA store (SMEM→GMEM). +D1.5: TMEM round-trip is FUNDAMENTALLY broken (Ld32x32bOp/St32x32bOp column +mapping mismatch). SMEM accumulator avoids round-trip entirely. """ +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +from cutlass.utils.blackwell_helpers import get_smem_store_op +from cutlass.utils.gemm.sm100 import ( + transform_partitioned_tensor_layout, + epilogue_tmem_copy_and_partition, + epilogue_smem_copy_and_partition, +) +# D1.5: TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken. +# Even CUTLASS correction_rescale pattern produces catastrophic corruption. +# SMEM accumulator approach: one-way TMEM→REGS→SMEM per kt iteration. +import cuda.bindings.driver as cuda +import cutlass.torch as ct +import math + + +class FmhaKernel: + def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1, apply_swa_mask=False, is_causal=False, n_comp=None, apply_sink_bias=False): + # D5c: n_comp = compressed KV length. Sink bias (attn_sink) applies to + # positions >= n_comp. D3/D4 masks also only apply to SWA region. + # When n_comp is None or 0, no offset (backward compatible). + self.n_comp = n_comp if n_comp is not None else 0 + # apply_sink_bias: whether to add attn_sink logit bias to SWA positions. + # Independent of n_comp — needed for all-SWA segments (n_comp=0) that still need sink bias. + # When True, adds sink_bias to positions >= n_comp (which is 0 → all positions). + self.apply_sink_bias = apply_sink_bias + self.head_dim = head_dim + self.s_k = s_k + self.n_kv_tiles = s_k // 128 + self.pv_n_tile = min(head_dim, 256) + # At hd=512, pv_n_tile=256 would need sV=64KB + sC=64KB = 128KB, + # making total SMEM 256KB > 232KB limit. Use pv_n_tile=128 for hd=512 + # (4 PV GEMM passes instead of 2). TODO: overlap sQ/sV to enable pv_n_tile=256. + if head_dim > 256: + self.pv_n_tile = 128 + self.n_pv_tiles = head_dim // self.pv_n_tile + self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64) + self.num_query_heads = num_query_heads + self.batch_size = batch_size + self.normalize = normalize # D5a: False = emit un-normalized O + lse + self.apply_swa_mask = apply_swa_mask # D3: mask logits at positions >= swa_lens + self.is_causal = is_causal # D4: causal mask (k_coord > m_coord) on SWA branch + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 + self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1 + self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0,1,2,3); self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + # K-dim sub-tiling: cap at 256 to keep sQ and sK within SMEM budget + self.k_tile = min(head_dim, 256) + self.n_k_sub_tiles = head_dim // self.k_tile + self.kv_stage = 1 if head_dim > 128 else 2 # Reduce SMEM at large hd + self.q_stage = 1 + self.num_c_stage = 1 if head_dim > 256 else 2 # Reduce SMEM at hd=512 + self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim) + self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e) + + def _setup(self, qk_mma, pv_mma): + qk_ik = cute.size(qk_mma.shape_mnk, mode=[2]) + # QK GEMM K-dim = head_dim. Each MMA sub-tile covers qk_ik*4 elements. + # The tiler K must be head_dim so the QK loop iterates over all K sub-tiles. + self.qk_mma_tiler = (128, 128, self.k_tile) + pv_ik = cute.size(pv_mma.shape_mnk, mode=[2]) + self.pv_mma_tiler = (128, self.pv_n_tile, pv_ik * (128 // pv_ik)) + self.mma_tiler = self.qk_mma_tiler + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), self.pv_n_tile, self.qk_mma_tiler[2]) + self.c_layout = LayoutEnum.ROW_MAJOR + self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage) + self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + # P SMEM layout (PV A-operand) — used for SMEM-P path + self.p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + qk_thr = qk_mma.get_slice(0); qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_as) + pv_thr = pv_mma.get_slice(0); pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_as) + self.tmem_s0_offset = 0 + if not self.use_smem_p: + # TMEM-P: S at 0, P at 32, O after P and S + self.tmem_p0_offset = 32 + p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width + p_end = self.tmem_p0_offset + p_cols_fp32 + s_cols = self.qk_mma_tiler[1] + o_after = max(s_cols, p_end) + self.tmem_o0_offset = ((o_after + 31) // 32) * 32 + o_cols = find_tmem_tensor_col_offset(tOtO) + total = self.tmem_o0_offset + o_cols + else: + # SMEM-P: P not in TMEM. S and O share TMEM (sequential). + self.tmem_p0_offset = -1 # unused + self.tmem_o0_offset = 0 + s_cols = self.qk_mma_tiler[1] + o_cols = find_tmem_tensor_col_offset(tOtO) + total = max(s_cols, o_cols) + self.num_tmem_alloc_cols = 1 + while self.num_tmem_alloc_cols < total: + self.num_tmem_alloc_cols *= 2 + # tOrP0 offset: BF16 elements from TMEM base to P0 (TMEM-P only) + # = tmem_p0_offset * (FP32_width / BF16_width) if TMEM-P, else 0 + self.tOrP0_offset = max(self.tmem_p0_offset, 0) * 2 # Python int + cta = cute.size(qk_mma.thr_id.shape) + q_s = cute.slice_(self.q_smem_s,(None,None,None,0)) + k_s = cute.slice_(self.k_smem_s,(None,None,None,0)) + v_s = cute.slice_(self.v_smem_s,(None,None,None,0)) + self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta + self.kv_tx_bytes = (cute.size_in_bytes(self.q_dtype, k_s) + + cute.size_in_bytes(self.q_dtype, v_s)) * cta + + @cute.jit + def __call__(self, q, k, v, c, stream, lse=None, swa_len=None, sink_bias=None, row_sums=None): + self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype + self.a_major = LayoutEnum.from_tensor(q).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(k).mma_major_mode() + v_fmha = cute.make_tensor( + v.iterator, + cute.make_layout( + (self.pv_n_tile, self.s_k, 1), + stride=(1, self.pv_n_tile, self.pv_n_tile * self.s_k), + ), + ) + self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode() + self.c_layout = LayoutEnum.from_tensor(c) + qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM) + pv_a_major = self.a_major if self.use_smem_p else cute.nvgpu.OperandMajorMode.K + pv_source = tcgen05.OperandSource.SMEM if self.use_smem_p else tcgen05.OperandSource.TMEM + pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, pv_a_major, self.v_major, self.qk_acc_dtype, self.cta_group, (128,self.pv_n_tile), pv_source) + self._setup(qk_mma, pv_mma) + q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0)) + tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape) + tma_k,mK = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,qk_mma.thr_id),k,k_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape) + tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape) + epi_s = cute.select(self.c_smem_s,mode=[0,1]) + tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile) + # Always create a valid mLSE tensor for the kernel. + # CuTeDSL doesn't support None parameters in @cute.kernel. + if const_expr(lse is None): + lse = cute.make_tensor(c.iterator, cute.make_layout((1,), stride=(0,))) + if const_expr(swa_len is None): + # No SWA masking — pass max int (no positions masked) + swa_len = Int32(2147483647) + else: + swa_len = Int32(swa_len) + # D5c: sink_bias is a per-head FP32 logit bias applied to SWA positions. + # When None, pass 0.0 (no bias). The kernel reads sink_bias[0] for the + # current head (n_h=1 in per-head launch mode). + if const_expr(sink_bias is None): + # D5c: sink_bias not provided. Create a dummy tensor pointing to valid memory. + # Never actually read (const_expr(self.n_comp > 0) guards the read). + sink_bias = cute.make_tensor(lse.iterator, cute.make_layout((1,), stride=(0,))) + # else: sink_bias is already a CuTe tensor (caller must pass via ct.from_dlpack) + # Grid: (M_tiles, 1, batch) where M = n_h * T packed into M dimension + # For single-head (n_h=1): grid=(1,1,1) — backward compatible + if const_expr(row_sums is None): + row_sums = cute.make_tensor(lse.iterator, lse.layout) + + self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse,swa_len,sink_bias,row_sums).launch(grid=(1,1,self.batch_size),block=[self.threads_per_cta,1,1],stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE, swa_len, mSinkBias, mRowSums): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx,_,_ = cute.arch.thread_idx() + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + q_bar: cute.struct.MemRange[cutlass.Int64, self.q_stage*2] + kv_bar: cute.struct.MemRange[cutlass.Int64, self.kv_stage*2] + s_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage*2] + tmem_dealloc: cutlass.Int64; holding: cutlass.Int32 + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants() + kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants() + s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants() + softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id)) + final_o_bar = pipeline.NamedBarrier(barrier_id=4, num_threads=32 + 32*len(self.epilogue_warp_id)) + # D1.5: pv_done_bar for SMEM accumulator approach. + # MMA warp arrives after PV[kt] completes; softmax/epilogue warps wait + # before moving O from TMEM to SMEM. + pv_done_bar = pipeline.NamedBarrier(barrier_id=5, num_threads=32 + 32*len(self.epilogue_warp_id)) + acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True) + tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk,is_relaxed=True) + + sQ = smem.allocate_tensor(element_type=self.q_dtype,layout=q_smem_s.outer,byte_alignment=128,swizzle=q_smem_s.inner) + sK = smem.allocate_tensor(element_type=self.q_dtype,layout=k_smem_s.outer,byte_alignment=128,swizzle=k_smem_s.inner) + # sV: independent allocation. At hd=512, pv_n_tile=128 keeps sV at 32KB. + # TODO: overlap sQ/sV with pv_n_tile=256 for better math throughput. + sV = smem.allocate_tensor(element_type=self.q_dtype,layout=v_smem_s.outer,byte_alignment=128,swizzle=v_smem_s.inner) + sC = smem.allocate_tensor(element_type=self.o_dtype,layout=c_smem_s.outer,byte_alignment=128,swizzle=c_smem_s.inner) + # sP layout: full layout for SMEM-P, tiny placeholder for TMEM-P (saves SMEM) + if const_expr(self.use_smem_p): + _p_layout = p_smem_s.outer + _p_swizzle = p_smem_s.inner + else: + _p_layout = cute.make_layout(((1,1),1,(1,1),1)) + _p_swizzle = cute.make_layout(((1,1),1,(1,1),1)) + sP = smem.allocate_tensor(element_type=self.q_dtype,layout=_p_layout,byte_alignment=128,swizzle=_p_swizzle) + + gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None)) + gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None)) + gV = cute.local_tile(mV,cute.slice_(self.pv_mma_tiler,(0,None,None)),(None,None,None)) + gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None)) + n_kv_tiles = cute.size(gK, mode=[3]) + + qk_thr = qk_mma.get_slice(0); pv_thr = pv_mma.get_slice(0) + tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK) + tCgV = pv_thr.partition_B(gV); tCgC = pv_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,0,None,0)).shape) + tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(sQ,0,3),cute.group_modes(tCgQ,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) + tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) + tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) + tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)] + + tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) + tCrV = pv_mma.make_fragment_B(sV) + + qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_as) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_as) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + # PV A-operand: define both tOrP0 (TMEM-P) and tCrP (SMEM-P) unconditionally. + # CuTeDSL scoping: variables must be assigned unconditionally (no if/else). + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + tOrP_base = pv_thr.make_fragment_A(tP if not self.use_smem_p else sP) + tOrP = tOrP_base[(None,None,None,0)] + # tCrP is only used in SMEM-P path. Define unconditionally for CuTeDSL scoping. + tCrP = pv_mma.make_fragment_A(sP) if self.use_smem_p else pv_mma.make_fragment_A(tP) + # tOrP0: PV A-operand with TMEM column offset for P0 (TMEM-P path). + # self.tOrP0_offset is pre-computed in _setup as a Python int. + # Use const_expr if/else for compile-time conditional. + if const_expr(self.tOrP0_offset > 0): + tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout) + else: + tOrP0 = tOrP + + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage)) + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # ===== TMA LOAD warp ===== + if warp_idx == self.tma_warp_id: + if const_expr(self.n_k_sub_tiles > 1): + # K sub-tiling path (hd>256): use cutlass.range loop to avoid IR explosion + # from Python range unrolling. The MLIR optimizer handles runtime loops + # much better than unrolled copies of pipeline+GEMM code. + qp.reset() + kvp.reset() + for k_sub in cutlass.range(0, self.n_k_sub_tiles, 1, unroll=1): + qh = qp.acquire_and_advance() + cute.copy(tma_q, tAgQ[(None, k_sub)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) + kvh = kvp.acquire_and_advance() + cute.copy(tma_k, tBgK[(None, k_sub)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + # Load V[0] + kvh_v = kvp.acquire_and_advance() + cute.copy(tma_v, tVgV[(None, Int32(0))], tVsV[(None, kvh_v.index)], tma_bar_ptr=kvh_v.barrier) + qp.tail() + kvp.tail() + else: + # Original pipeline path (hd≤256) + qp.reset(); qh = qp.acquire_and_advance() + cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) + qp.tail() + kvp.reset(); pk = kvp.try_acquire() + for kt in cutlass.range(0, self.n_kv_tiles, 1, unroll=1): + kvh = kvp.acquire_and_advance(pk) + cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + pk = cutlass.Boolean(1) + kvp.tail() + + # ===== MMA warp ===== + if warp_idx == self.mma_warp_id: + tmem.wait_for_alloc() + if const_expr(self.n_k_sub_tiles > 1): + # K sub-tiling path (hd>256): cutlass.range loop (runtime, not unrolled) + qc.reset() + kvc.reset() + qk_mma.set(tcgen05.Field.ACCUMULATE, False) + for k_sub in cutlass.range(0, self.n_k_sub_tiles, 1, unroll=1): + qh = qc.wait_and_advance(); qh.release() + kvh = kvc.wait_and_advance() + for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True): + cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0) + qk_mma.set(tcgen05.Field.ACCUMULATE, True) + kvh.release() + # After all k_sub: S has full QK for this kt + cute.arch.fence_view_async_tmem_store() + softmax_done_bar.arrive() + softmax_done_bar.arrive_and_wait() + pv_mma.set(tcgen05.Field.ACCUMULATE, False) + # Load V: consume from K/V pipeline + kvh_v = kvc.wait_and_advance() + if not self.use_smem_p: + for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True): + cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh_v.index)], tOtO0) + pv_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True): + cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh_v.index)], tOtO0) + pv_mma.set(tcgen05.Field.ACCUMULATE, True) + cute.arch.fence_view_async_tmem_store() + kvh_v.release() + pv_done_bar.arrive() # D1.5: Signal epilogue warps O_kt ready in TMEM + final_o_bar.arrive() + else: + # Original pipeline path (hd≤256) + qc.reset(); qh = qc.wait_and_advance(); qh.release() + kvc.reset(); pk = kvc.try_wait() + acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage) + acc_pipe.producer_acquire(acc_st) + for kt in range(self.n_kv_tiles): + kvh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1) + sh = s_prod.acquire_and_advance() + qk_mma.set(tcgen05.Field.ACCUMULATE, False) + for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True): + cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0) + qk_mma.set(tcgen05.Field.ACCUMULATE, True) + cute.arch.fence_view_async_tmem_store() + sh.commit() + softmax_done_bar.arrive_and_wait() + pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0) + if not self.use_smem_p: + for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True): + cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0) + pv_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True): + cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh.index)], tOtO0) + pv_mma.set(tcgen05.Field.ACCUMULATE, True) + cute.arch.fence_view_async_tmem_store() + kvh.release() + pv_done_bar.arrive() # D1.5: Signal epilogue warps O_kt ready in TMEM + acc_pipe.producer_commit(acc_st); acc_st.advance() + final_o_bar.arrive() + acc_pipe.producer_tail(acc_st) + + # ===== SOFTMAX + CORRECTION EPILOGUE warps ===== + if warp_idx < self.mma_warp_id: + tmem.allocate(self.num_tmem_alloc_cols) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + sfw_idx = tidx % (32 * len(self.epilogue_warp_id)) + + # S load atoms + tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0) + thr_load = tiled_tmem_load.get_slice(sfw_idx) + tTMEM_LOADtS = thr_load.partition_S(tStS0) + cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS) + tTMEM_LOADcS = thr_load.partition_D(tScS) + + # P store atoms: TMEM-P (always defined, only used when use_smem_p=False) + p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width + tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32))) + # Use 0 as P offset when SMEM-P (these atoms are never used, but must be valid) + tStP0 = cute.make_tensor(tStS.iterator + max(self.tmem_p0_offset, 0), tStP_layout) + tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0) + thr_store = tiled_tmem_store.get_slice(sfw_idx) + tTMEM_STOREtP = thr_store.partition_D(tStP0) + tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32))) + tScP = cute.make_tensor(tScS.iterator, tScP_layout) + tTMEM_STOREcP = thr_store.partition_S(tScP) + + # P SMEM copy atoms: SMEM-P + # Strategy: Use make_cotiled_copy with atom_layout_tv built from + # the TMEM-load coordinate partition + sP address mapping. + # + # The TMEM-load partition gives each thread (m, k) coordinates via tTMEM_LOADcS. + # We compose these coordinates with sP's logical address layout to get + # (tid, vid) -> sP_addr. Then make_cotiled_copy creates a proper TiledCopy. + # + # Key: sP's outer layout maps (m, k0, k1, k2) -> sP_addr with strides (64, 1, 16, 8192). + # We need to build atom_layout_tv in sP's flat address space, not tStS's. + # + # Step 1: Build sP address mapping in the same coordinate system as tStS. + # sP is indexed as ((m, k%16), 0, ((k//16)%4, k//64)) with strides ((64,1),0,(16,8192)). + # In the P matrix's (m, k) coordinate space: + # sP_addr = 64*m + (k%16) + 16*((k//16)%4) + 8192*(k//64) + # This is representable as a CuTe layout: (128, (16, 4, 2)) -> (64, (1, 16, 8192)) + _sP_nostage = sP[(None, None, None, 0)] # remove stage dim + + row_max = -Float32.inf + row_sum = Float32(0.0) + scale_log2 = Float32(self.scale_softmax_log2) + + # ============================================================ + # D1.5: O RESCALE — SMEM ACCUMULATOR APPROACH + # ================================================= + # TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken: + # even NO-OP round-trip corrupts data (ratio = -11 billion). + # Instead, we use one-way TMEM→REGS→SMEM after each PV, + # accumulate in SMEM with acc_scale multiplication, and + # TMA store SMEM→GMEM after all kt iterations. + # + # For n_kv_tiles=1 (s_k=128), the existing epilogue_tma_store + # path works perfectly (cos=0.999998). The SMEM accumulator + # is only needed for n_kv_tiles > 1. + # ============================================================ + + # NOTE: The code below is the BROKEN TMEM round-trip approach. + # It's kept as reference but should NOT be used. + # The SMEM accumulator implementation is TODO. + + # prev_acc_scale: unused, kept for clarity. acc_scale at kt is used + # to rescale O from kt=0..kt-1 before PV[kt]. + prev_acc_scale = Float32(0.0) + + for kt in range(self.n_kv_tiles): + si_handle = s_cons.wait_and_advance() + + tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype) + cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS) + cute.arch.fence_view_async_tmem_load() + + # D3/D4/D5c: In-kernel logit modification. + # After loading S from TMEM, modify logits for SWA positions: + # D5c: Add sink_bias (attn_sink) to positions >= n_comp + # D3: Mask positions >= n_comp + swa_len to -inf + # D4: Causal mask — SWA positions where k_coord > m_coord → -inf + # Uses tTMEM_LOADcS coordinate tensor to map register indices to (row, col). + # For kt > 0, absolute KV pos = kt*128 + k_coord. + if const_expr(self.apply_swa_mask or self.is_causal or self.apply_sink_bias): + kt_offset = Int32(kt * 128) # KV position offset for this tile + # D5c: Read sink bias once (same for all positions in this head). + # Define unconditionally for CuTeDSL scoping (used when apply_sink_bias). + # The bias must be added in the SCALED-LOG2 domain: attn_sink * log2(e). + # But we add to the RAW logits before the scale_log2 multiply. + # Raw correction: attn_sink / scale → after * scale_log2 → attn_sink * log2(e) + sink_val = Float32(0.0) + if const_expr(self.apply_sink_bias): + sink_val = mSinkBias[Int32(0)] / Float32(self.scale_softmax) + for j0 in range(32): + for j1 in range(4): + coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0] + m_coord = coord[0] # query row position + k_coord = coord[1] # position within this KV tile + kv_pos = kt_offset + k_coord # absolute KV position + # D5c: Add sink bias to SWA positions (>= n_comp) + if const_expr(self.apply_sink_bias): + if kv_pos >= Int32(self.n_comp): + tTMEM_LOADrS[(j0, 0), j1, 0, 0] = tTMEM_LOADrS[(j0, 0), j1, 0, 0] + sink_val + # D3: SWA length mask + should_mask = Boolean(0) + if const_expr(self.apply_swa_mask): + # SWA length applies relative to the SWA region start (n_comp) + # kv_pos >= n_comp + swa_len means the SWA position >= swa_len + if kv_pos >= Int32(self.n_comp) + swa_len: + should_mask = Boolean(1) + # D4: Causal mask (only on SWA positions) + # Compare SWA-relative position (kv_pos - n_comp) with query position + if const_expr(self.is_causal): + if kv_pos >= Int32(self.n_comp): + swa_pos = kv_pos - Int32(self.n_comp) + if swa_pos > m_coord: + should_mask = Boolean(1) + if should_mask: + tTMEM_LOADrS[(j0, 0), j1, 0, 0] = -Float32.inf + + old_row_max = row_max + frg_cnt = 4 + frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt + tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) + for j in range(frg_cnt): + for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): + row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2) + + row_max_safe = row_max + if row_max == -cutlass.Float32.inf: + row_max_safe = Float32(0.0) + + acc_scale_ = old_row_max - row_max_safe + acc_scale = cute.math.exp2(acc_scale_, fastmath=True) + if old_row_max == -cutlass.Float32.inf: + acc_scale = Float32(0.0) + row_sum *= acc_scale + + rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype) + rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout) + minus_row_max = Float32(0.0) - row_max_safe + + rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile)) + for j in range(frg_cnt): + for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): + tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max + tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True) + row_sum = row_sum + tTMEM_LOADrS_frg[k, j] + s_vec = tTMEM_LOADrS_frg[None, j].load() + rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype)) + + if not self.use_smem_p: + # TMEM-P: store P to TMEM via register bridge + cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) + cute.arch.fence_view_async_tmem_store() + else: + # SMEM-P: write P to sP using coordinate-indexed store. + for j0 in range(32): + for j1 in range(4): + coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0] + m_coord = coord[0] + k_coord = coord[1] + k0 = k_coord % 16 + k1 = (k_coord // 16) % 4 + k2 = k_coord // 64 + _sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0] + cute.arch.fence_proxy("async.shared", space="cta") + # D1.5: O rescale for kt > 0 — NOT YET IMPLEMENTED. + # TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken: + # even NO-OP round-trip corrupts O accumulator data. + # Production path for multi-KV-tile: Python KV merge (cos 0.999998). + # Future: SMEM accumulator approach (one-way TMEM→REGS→SMEM per kt). + # n_kv_tiles=1 is the only supported path for in-kernel processing. + + si_handle.release() + softmax_done_bar.arrive() + + # Wait for MMA's PV[N-1] to commit before reading O. + final_o_bar.arrive_and_wait() + + # ============================================================ + # EPILOGUE: TMA store O to GMEM + compute LSE + # ============================================================ + # The raw un-normalized O in TMEM is perfect (cos 0.999998). + # We use epilogue_tma_store which reads O from TMEM directly via + # the correct get_tmem_load_op layout — no round-trip needed. + # + # For multi-KV-tile: the paired-atom O rescale above (kt>0) ensures + # O is correctly rescaled before this epilogue reads it. + # + # External normalization (D5a path): kernel outputs un-normalized O + + # LSE + row_sum. Caller normalizes using O_norm = O_unnorm / row_sum. + # This is exact and composes with D5c sink bias merge. + # ============================================================ + + # TMA store via CUTLASS epilogue_tma_store (reads raw O from TMEM) + tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) + c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) + c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp) + acc_cons_st = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + acc_cons_st = utils.gemm.sm100.epilogue_tma_store( + self, sfw_idx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, + 0, const_expr(lambda x: x), (0, 0, 0), + acc_cons_st, acc_pipe, c_pipe, + ) + c_pipe.producer_tail() + + # Compute LSE: lse = ln(row_sum) + row_max * ln(2) + # Only when emitting un-normalized output (D5a path). + # When normalize=True, LSE is not needed (in-kernel normalization). + # + # Per-row LSE: each softmax thread (sfw_idx 0..127) handles one row. + # sfw_idx maps directly to the row index in the attention matrix. + # All 128 threads write independently to mLSE[sfw_idx] — no sync needed. + if const_expr(not self.normalize): + _row_max_safe = row_max + if row_max == -cutlass.Float32.inf: + _row_max_safe = Float32(0.0) + _ln2 = Float32(0.6931471805599453) # ln(2) + lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2 + mLSE[sfw_idx, Int32(0), Int32(0)] = lse_val + # Also output row_sum for external normalization (D5c) + mRowSums[sfw_idx, Int32(0), Int32(0)] = row_sum + + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr) diff --git a/dsv4/kernels/attention/archive/fmha_smem_acc.py b/dsv4/kernels/attention/archive/fmha_smem_acc.py index 555ea1f3..acb94d31 100644 --- a/dsv4/kernels/attention/archive/fmha_smem_acc.py +++ b/dsv4/kernels/attention/archive/fmha_smem_acc.py @@ -15,4 +15,596 @@ WHY IT DIDN'T SHIP: SMEM budget at hd=512 is already tight (192KB). Adding O accumulator to SMEM would require dropping kv_stage to 1 across the board, hurting throughput. The register-based approach in raw CUDA is better — registers are free. + +FMHA kernel: QK -> online softmax -> PV (CuTeDSL, Blackwell SM100). + +Migrated from tests/unit/test_fmha_v3_stage_c.py — Stage C proven path. +P stored to TMEM via register bridge, PV reads from TMEM. +O rescale via SMEM accumulator (one-way TMEM→REGS→SMEM per kt iteration). +Normalization via final TMA store (SMEM→GMEM). +D1.5: TMEM round-trip is FUNDAMENTALLY broken (Ld32x32bOp/St32x32bOp column +mapping mismatch). SMEM accumulator avoids round-trip entirely. """ +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +from cutlass.utils.blackwell_helpers import get_smem_store_op +from cutlass.utils.gemm.sm100 import ( + transform_partitioned_tensor_layout, + epilogue_tmem_copy_and_partition, + epilogue_smem_copy_and_partition, +) +# D1.5: TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken. +# Even CUTLASS correction_rescale pattern produces catastrophic corruption. +# SMEM accumulator approach: one-way TMEM→REGS→SMEM per kt iteration. +import cuda.bindings.driver as cuda +import cutlass.torch as ct +import math + + +class FmhaKernel: + def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1, apply_swa_mask=False, is_causal=False, n_comp=None, apply_sink_bias=False): + # D5c: n_comp = compressed KV length. Sink bias (attn_sink) applies to + # positions >= n_comp. D3/D4 masks also only apply to SWA region. + # When n_comp is None or 0, no offset (backward compatible). + self.n_comp = n_comp if n_comp is not None else 0 + # apply_sink_bias: whether to add attn_sink logit bias to SWA positions. + # Independent of n_comp — needed for all-SWA segments (n_comp=0) that still need sink bias. + # When True, adds sink_bias to positions >= n_comp (which is 0 → all positions). + self.apply_sink_bias = apply_sink_bias + self.head_dim = head_dim + self.s_k = s_k + self.n_kv_tiles = s_k // 128 + self.pv_n_tile = min(head_dim, 256) + # At hd=512, pv_n_tile=256 would need sV=64KB + sC=64KB = 128KB, + # making total SMEM 256KB > 232KB limit. Use pv_n_tile=128 for hd=512 + # (4 PV GEMM passes instead of 2). TODO: overlap sQ/sV to enable pv_n_tile=256. + if head_dim > 256: + self.pv_n_tile = 128 + self.n_pv_tiles = head_dim // self.pv_n_tile + self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64) + self.num_query_heads = num_query_heads + self.batch_size = batch_size + self.normalize = normalize # D5a: False = emit un-normalized O + lse + self.apply_swa_mask = apply_swa_mask # D3: mask logits at positions >= swa_lens + self.is_causal = is_causal # D4: causal mask (k_coord > m_coord) on SWA branch + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 + self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1 + self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0,1,2,3); self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + # K-dim sub-tiling: cap at 256 to keep sQ and sK within SMEM budget + self.k_tile = min(head_dim, 256) + self.n_k_sub_tiles = head_dim // self.k_tile + self.kv_stage = 1 if head_dim > 128 else 2 # Reduce SMEM at large hd + self.q_stage = 1 + self.num_c_stage = 1 if head_dim > 256 else 2 # Reduce SMEM at hd=512 + self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim) + self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e) + + def _setup(self, qk_mma, pv_mma): + qk_ik = cute.size(qk_mma.shape_mnk, mode=[2]) + # QK GEMM K-dim = head_dim. Each MMA sub-tile covers qk_ik*4 elements. + # The tiler K must be head_dim so the QK loop iterates over all K sub-tiles. + self.qk_mma_tiler = (128, 128, self.k_tile) + pv_ik = cute.size(pv_mma.shape_mnk, mode=[2]) + self.pv_mma_tiler = (128, self.pv_n_tile, pv_ik * (128 // pv_ik)) + self.mma_tiler = self.qk_mma_tiler + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), self.pv_n_tile, self.qk_mma_tiler[2]) + self.c_layout = LayoutEnum.ROW_MAJOR + self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage) + self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + # P SMEM layout (PV A-operand) — used for SMEM-P path + self.p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + qk_thr = qk_mma.get_slice(0); qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_as) + pv_thr = pv_mma.get_slice(0); pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_as) + self.tmem_s0_offset = 0 + if not self.use_smem_p: + # TMEM-P: S at 0, P at 32, O after P and S + self.tmem_p0_offset = 32 + p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width + p_end = self.tmem_p0_offset + p_cols_fp32 + s_cols = self.qk_mma_tiler[1] + o_after = max(s_cols, p_end) + self.tmem_o0_offset = ((o_after + 31) // 32) * 32 + o_cols = find_tmem_tensor_col_offset(tOtO) + total = self.tmem_o0_offset + o_cols + else: + # SMEM-P: P not in TMEM. S and O share TMEM (sequential). + self.tmem_p0_offset = -1 # unused + self.tmem_o0_offset = 0 + s_cols = self.qk_mma_tiler[1] + o_cols = find_tmem_tensor_col_offset(tOtO) + total = max(s_cols, o_cols) + self.num_tmem_alloc_cols = 1 + while self.num_tmem_alloc_cols < total: + self.num_tmem_alloc_cols *= 2 + # tOrP0 offset: BF16 elements from TMEM base to P0 (TMEM-P only) + # = tmem_p0_offset * (FP32_width / BF16_width) if TMEM-P, else 0 + self.tOrP0_offset = max(self.tmem_p0_offset, 0) * 2 # Python int + cta = cute.size(qk_mma.thr_id.shape) + q_s = cute.slice_(self.q_smem_s,(None,None,None,0)) + k_s = cute.slice_(self.k_smem_s,(None,None,None,0)) + v_s = cute.slice_(self.v_smem_s,(None,None,None,0)) + self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta + self.kv_tx_bytes = (cute.size_in_bytes(self.q_dtype, k_s) + + cute.size_in_bytes(self.q_dtype, v_s)) * cta + + @cute.jit + def __call__(self, q, k, v, c, stream, lse=None, swa_len=None, sink_bias=None, row_sums=None): + self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype + self.a_major = LayoutEnum.from_tensor(q).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(k).mma_major_mode() + v_fmha = cute.make_tensor( + v.iterator, + cute.make_layout( + (self.pv_n_tile, self.s_k, 1), + stride=(1, self.pv_n_tile, self.pv_n_tile * self.s_k), + ), + ) + self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode() + self.c_layout = LayoutEnum.from_tensor(c) + qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM) + pv_a_major = self.a_major if self.use_smem_p else cute.nvgpu.OperandMajorMode.K + pv_source = tcgen05.OperandSource.SMEM if self.use_smem_p else tcgen05.OperandSource.TMEM + pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, pv_a_major, self.v_major, self.qk_acc_dtype, self.cta_group, (128,self.pv_n_tile), pv_source) + self._setup(qk_mma, pv_mma) + q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0)) + tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape) + tma_k,mK = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,qk_mma.thr_id),k,k_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape) + tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape) + epi_s = cute.select(self.c_smem_s,mode=[0,1]) + tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile) + # Always create a valid mLSE tensor for the kernel. + # CuTeDSL doesn't support None parameters in @cute.kernel. + if const_expr(lse is None): + lse = cute.make_tensor(c.iterator, cute.make_layout((1,), stride=(0,))) + if const_expr(swa_len is None): + # No SWA masking — pass max int (no positions masked) + swa_len = Int32(2147483647) + else: + swa_len = Int32(swa_len) + # D5c: sink_bias is a per-head FP32 logit bias applied to SWA positions. + # When None, pass 0.0 (no bias). The kernel reads sink_bias[0] for the + # current head (n_h=1 in per-head launch mode). + if const_expr(sink_bias is None): + # D5c: sink_bias not provided. Create a dummy tensor pointing to valid memory. + # Never actually read (const_expr(self.n_comp > 0) guards the read). + sink_bias = cute.make_tensor(lse.iterator, cute.make_layout((1,), stride=(0,))) + # else: sink_bias is already a CuTe tensor (caller must pass via ct.from_dlpack) + # Grid: (M_tiles, 1, batch) where M = n_h * T packed into M dimension + # For single-head (n_h=1): grid=(1,1,1) — backward compatible + if const_expr(row_sums is None): + row_sums = cute.make_tensor(lse.iterator, lse.layout) + + self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse,swa_len,sink_bias,row_sums).launch(grid=(1,1,self.batch_size),block=[self.threads_per_cta,1,1],stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE, swa_len, mSinkBias, mRowSums): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx,_,_ = cute.arch.thread_idx() + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + q_bar: cute.struct.MemRange[cutlass.Int64, self.q_stage*2] + kv_bar: cute.struct.MemRange[cutlass.Int64, self.kv_stage*2] + s_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage*2] + tmem_dealloc: cutlass.Int64; holding: cutlass.Int32 + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants() + kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants() + s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants() + softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id)) + final_o_bar = pipeline.NamedBarrier(barrier_id=4, num_threads=32 + 32*len(self.epilogue_warp_id)) + # D1.5: pv_done_bar for SMEM accumulator approach. + # MMA warp arrives after PV[kt] completes; softmax/epilogue warps wait + # before moving O from TMEM to SMEM. + pv_done_bar = pipeline.NamedBarrier(barrier_id=5, num_threads=32 + 32*len(self.epilogue_warp_id)) + acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True) + tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk,is_relaxed=True) + + sQ = smem.allocate_tensor(element_type=self.q_dtype,layout=q_smem_s.outer,byte_alignment=128,swizzle=q_smem_s.inner) + sK = smem.allocate_tensor(element_type=self.q_dtype,layout=k_smem_s.outer,byte_alignment=128,swizzle=k_smem_s.inner) + # sV: independent allocation. At hd=512, pv_n_tile=128 keeps sV at 32KB. + # TODO: overlap sQ/sV with pv_n_tile=256 for better math throughput. + sV = smem.allocate_tensor(element_type=self.q_dtype,layout=v_smem_s.outer,byte_alignment=128,swizzle=v_smem_s.inner) + sC = smem.allocate_tensor(element_type=self.o_dtype,layout=c_smem_s.outer,byte_alignment=128,swizzle=c_smem_s.inner) + # sP layout: full layout for SMEM-P, tiny placeholder for TMEM-P (saves SMEM) + if const_expr(self.use_smem_p): + _p_layout = p_smem_s.outer + _p_swizzle = p_smem_s.inner + else: + _p_layout = cute.make_layout(((1,1),1,(1,1),1)) + _p_swizzle = cute.make_layout(((1,1),1,(1,1),1)) + sP = smem.allocate_tensor(element_type=self.q_dtype,layout=_p_layout,byte_alignment=128,swizzle=_p_swizzle) + + gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None)) + gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None)) + gV = cute.local_tile(mV,cute.slice_(self.pv_mma_tiler,(0,None,None)),(None,None,None)) + gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None)) + n_kv_tiles = cute.size(gK, mode=[3]) + + qk_thr = qk_mma.get_slice(0); pv_thr = pv_mma.get_slice(0) + tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK) + tCgV = pv_thr.partition_B(gV); tCgC = pv_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,0,None,0)).shape) + tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(sQ,0,3),cute.group_modes(tCgQ,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) + tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) + tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) + tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)] + + tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) + tCrV = pv_mma.make_fragment_B(sV) + + qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_as) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_as) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + # PV A-operand: define both tOrP0 (TMEM-P) and tCrP (SMEM-P) unconditionally. + # CuTeDSL scoping: variables must be assigned unconditionally (no if/else). + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + tOrP_base = pv_thr.make_fragment_A(tP if not self.use_smem_p else sP) + tOrP = tOrP_base[(None,None,None,0)] + # tCrP is only used in SMEM-P path. Define unconditionally for CuTeDSL scoping. + tCrP = pv_mma.make_fragment_A(sP) if self.use_smem_p else pv_mma.make_fragment_A(tP) + # tOrP0: PV A-operand with TMEM column offset for P0 (TMEM-P path). + # self.tOrP0_offset is pre-computed in _setup as a Python int. + # Use const_expr if/else for compile-time conditional. + if const_expr(self.tOrP0_offset > 0): + tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout) + else: + tOrP0 = tOrP + + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage)) + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # ===== TMA LOAD warp ===== + if warp_idx == self.tma_warp_id: + if const_expr(self.n_k_sub_tiles > 1): + # K sub-tiling path (hd>256): use cutlass.range loop to avoid IR explosion + # from Python range unrolling. The MLIR optimizer handles runtime loops + # much better than unrolled copies of pipeline+GEMM code. + qp.reset() + kvp.reset() + for k_sub in cutlass.range(0, self.n_k_sub_tiles, 1, unroll=1): + qh = qp.acquire_and_advance() + cute.copy(tma_q, tAgQ[(None, k_sub)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) + kvh = kvp.acquire_and_advance() + cute.copy(tma_k, tBgK[(None, k_sub)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + # Load V[0] + kvh_v = kvp.acquire_and_advance() + cute.copy(tma_v, tVgV[(None, Int32(0))], tVsV[(None, kvh_v.index)], tma_bar_ptr=kvh_v.barrier) + qp.tail() + kvp.tail() + else: + # Original pipeline path (hd≤256) + qp.reset(); qh = qp.acquire_and_advance() + cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) + qp.tail() + kvp.reset(); pk = kvp.try_acquire() + for kt in cutlass.range(0, self.n_kv_tiles, 1, unroll=1): + kvh = kvp.acquire_and_advance(pk) + cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + pk = cutlass.Boolean(1) + kvp.tail() + + # ===== MMA warp ===== + if warp_idx == self.mma_warp_id: + tmem.wait_for_alloc() + if const_expr(self.n_k_sub_tiles > 1): + # K sub-tiling path (hd>256): cutlass.range loop (runtime, not unrolled) + qc.reset() + kvc.reset() + qk_mma.set(tcgen05.Field.ACCUMULATE, False) + for k_sub in cutlass.range(0, self.n_k_sub_tiles, 1, unroll=1): + qh = qc.wait_and_advance(); qh.release() + kvh = kvc.wait_and_advance() + for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True): + cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0) + qk_mma.set(tcgen05.Field.ACCUMULATE, True) + kvh.release() + # After all k_sub: S has full QK for this kt + cute.arch.fence_view_async_tmem_store() + softmax_done_bar.arrive() + softmax_done_bar.arrive_and_wait() + pv_mma.set(tcgen05.Field.ACCUMULATE, False) + # Load V: consume from K/V pipeline + kvh_v = kvc.wait_and_advance() + if not self.use_smem_p: + for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True): + cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh_v.index)], tOtO0) + pv_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True): + cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh_v.index)], tOtO0) + pv_mma.set(tcgen05.Field.ACCUMULATE, True) + cute.arch.fence_view_async_tmem_store() + kvh_v.release() + pv_done_bar.arrive() # D1.5: Signal epilogue warps O_kt ready in TMEM + final_o_bar.arrive() + else: + # Original pipeline path (hd≤256) + qc.reset(); qh = qc.wait_and_advance(); qh.release() + kvc.reset(); pk = kvc.try_wait() + acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage) + acc_pipe.producer_acquire(acc_st) + for kt in range(self.n_kv_tiles): + kvh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1) + sh = s_prod.acquire_and_advance() + qk_mma.set(tcgen05.Field.ACCUMULATE, False) + for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True): + cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0) + qk_mma.set(tcgen05.Field.ACCUMULATE, True) + cute.arch.fence_view_async_tmem_store() + sh.commit() + softmax_done_bar.arrive_and_wait() + pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0) + if not self.use_smem_p: + for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True): + cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0) + pv_mma.set(tcgen05.Field.ACCUMULATE, True) + else: + for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True): + cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh.index)], tOtO0) + pv_mma.set(tcgen05.Field.ACCUMULATE, True) + cute.arch.fence_view_async_tmem_store() + kvh.release() + pv_done_bar.arrive() # D1.5: Signal epilogue warps O_kt ready in TMEM + acc_pipe.producer_commit(acc_st); acc_st.advance() + final_o_bar.arrive() + acc_pipe.producer_tail(acc_st) + + # ===== SOFTMAX + CORRECTION EPILOGUE warps ===== + if warp_idx < self.mma_warp_id: + tmem.allocate(self.num_tmem_alloc_cols) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) + sfw_idx = tidx % (32 * len(self.epilogue_warp_id)) + + # S load atoms + tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0) + thr_load = tiled_tmem_load.get_slice(sfw_idx) + tTMEM_LOADtS = thr_load.partition_S(tStS0) + cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS) + tTMEM_LOADcS = thr_load.partition_D(tScS) + + # P store atoms: TMEM-P (always defined, only used when use_smem_p=False) + p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width + tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32))) + # Use 0 as P offset when SMEM-P (these atoms are never used, but must be valid) + tStP0 = cute.make_tensor(tStS.iterator + max(self.tmem_p0_offset, 0), tStP_layout) + tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0) + thr_store = tiled_tmem_store.get_slice(sfw_idx) + tTMEM_STOREtP = thr_store.partition_D(tStP0) + tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32))) + tScP = cute.make_tensor(tScS.iterator, tScP_layout) + tTMEM_STOREcP = thr_store.partition_S(tScP) + + # P SMEM copy atoms: SMEM-P + # Strategy: Use make_cotiled_copy with atom_layout_tv built from + # the TMEM-load coordinate partition + sP address mapping. + # + # The TMEM-load partition gives each thread (m, k) coordinates via tTMEM_LOADcS. + # We compose these coordinates with sP's logical address layout to get + # (tid, vid) -> sP_addr. Then make_cotiled_copy creates a proper TiledCopy. + # + # Key: sP's outer layout maps (m, k0, k1, k2) -> sP_addr with strides (64, 1, 16, 8192). + # We need to build atom_layout_tv in sP's flat address space, not tStS's. + # + # Step 1: Build sP address mapping in the same coordinate system as tStS. + # sP is indexed as ((m, k%16), 0, ((k//16)%4, k//64)) with strides ((64,1),0,(16,8192)). + # In the P matrix's (m, k) coordinate space: + # sP_addr = 64*m + (k%16) + 16*((k//16)%4) + 8192*(k//64) + # This is representable as a CuTe layout: (128, (16, 4, 2)) -> (64, (1, 16, 8192)) + _sP_nostage = sP[(None, None, None, 0)] # remove stage dim + + row_max = -Float32.inf + row_sum = Float32(0.0) + scale_log2 = Float32(self.scale_softmax_log2) + + # ============================================================ + # D1.5: O RESCALE — SMEM ACCUMULATOR APPROACH + # ================================================= + # TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken: + # even NO-OP round-trip corrupts data (ratio = -11 billion). + # Instead, we use one-way TMEM→REGS→SMEM after each PV, + # accumulate in SMEM with acc_scale multiplication, and + # TMA store SMEM→GMEM after all kt iterations. + # + # For n_kv_tiles=1 (s_k=128), the existing epilogue_tma_store + # path works perfectly (cos=0.999998). The SMEM accumulator + # is only needed for n_kv_tiles > 1. + # ============================================================ + + # NOTE: The code below is the BROKEN TMEM round-trip approach. + # It's kept as reference but should NOT be used. + # The SMEM accumulator implementation is TODO. + + # prev_acc_scale: unused, kept for clarity. acc_scale at kt is used + # to rescale O from kt=0..kt-1 before PV[kt]. + prev_acc_scale = Float32(0.0) + + for kt in range(self.n_kv_tiles): + si_handle = s_cons.wait_and_advance() + + tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype) + cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS) + cute.arch.fence_view_async_tmem_load() + + # D3/D4/D5c: In-kernel logit modification. + # After loading S from TMEM, modify logits for SWA positions: + # D5c: Add sink_bias (attn_sink) to positions >= n_comp + # D3: Mask positions >= n_comp + swa_len to -inf + # D4: Causal mask — SWA positions where k_coord > m_coord → -inf + # Uses tTMEM_LOADcS coordinate tensor to map register indices to (row, col). + # For kt > 0, absolute KV pos = kt*128 + k_coord. + if const_expr(self.apply_swa_mask or self.is_causal or self.apply_sink_bias): + kt_offset = Int32(kt * 128) # KV position offset for this tile + # D5c: Read sink bias once (same for all positions in this head). + # Define unconditionally for CuTeDSL scoping (used when apply_sink_bias). + # The bias must be added in the SCALED-LOG2 domain: attn_sink * log2(e). + # But we add to the RAW logits before the scale_log2 multiply. + # Raw correction: attn_sink / scale → after * scale_log2 → attn_sink * log2(e) + sink_val = Float32(0.0) + if const_expr(self.apply_sink_bias): + sink_val = mSinkBias[Int32(0)] / Float32(self.scale_softmax) + for j0 in range(32): + for j1 in range(4): + coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0] + m_coord = coord[0] # query row position + k_coord = coord[1] # position within this KV tile + kv_pos = kt_offset + k_coord # absolute KV position + # D5c: Add sink bias to SWA positions (>= n_comp) + if const_expr(self.apply_sink_bias): + if kv_pos >= Int32(self.n_comp): + tTMEM_LOADrS[(j0, 0), j1, 0, 0] = tTMEM_LOADrS[(j0, 0), j1, 0, 0] + sink_val + # D3: SWA length mask + should_mask = Boolean(0) + if const_expr(self.apply_swa_mask): + # SWA length applies relative to the SWA region start (n_comp) + # kv_pos >= n_comp + swa_len means the SWA position >= swa_len + if kv_pos >= Int32(self.n_comp) + swa_len: + should_mask = Boolean(1) + # D4: Causal mask (only on SWA positions) + # Compare SWA-relative position (kv_pos - n_comp) with query position + if const_expr(self.is_causal): + if kv_pos >= Int32(self.n_comp): + swa_pos = kv_pos - Int32(self.n_comp) + if swa_pos > m_coord: + should_mask = Boolean(1) + if should_mask: + tTMEM_LOADrS[(j0, 0), j1, 0, 0] = -Float32.inf + + old_row_max = row_max + frg_cnt = 4 + frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt + tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) + for j in range(frg_cnt): + for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): + row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2) + + row_max_safe = row_max + if row_max == -cutlass.Float32.inf: + row_max_safe = Float32(0.0) + + acc_scale_ = old_row_max - row_max_safe + acc_scale = cute.math.exp2(acc_scale_, fastmath=True) + if old_row_max == -cutlass.Float32.inf: + acc_scale = Float32(0.0) + row_sum *= acc_scale + + rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype) + rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout) + minus_row_max = Float32(0.0) - row_max_safe + + rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile)) + for j in range(frg_cnt): + for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): + tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max + tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True) + row_sum = row_sum + tTMEM_LOADrS_frg[k, j] + s_vec = tTMEM_LOADrS_frg[None, j].load() + rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype)) + + if not self.use_smem_p: + # TMEM-P: store P to TMEM via register bridge + cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) + cute.arch.fence_view_async_tmem_store() + else: + # SMEM-P: write P to sP using coordinate-indexed store. + for j0 in range(32): + for j1 in range(4): + coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0] + m_coord = coord[0] + k_coord = coord[1] + k0 = k_coord % 16 + k1 = (k_coord // 16) % 4 + k2 = k_coord // 64 + _sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0] + cute.arch.fence_proxy("async.shared", space="cta") + # D1.5: O rescale for kt > 0 — NOT YET IMPLEMENTED. + # TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken: + # even NO-OP round-trip corrupts O accumulator data. + # Production path for multi-KV-tile: Python KV merge (cos 0.999998). + # Future: SMEM accumulator approach (one-way TMEM→REGS→SMEM per kt). + # n_kv_tiles=1 is the only supported path for in-kernel processing. + + si_handle.release() + softmax_done_bar.arrive() + + # Wait for MMA's PV[N-1] to commit before reading O. + final_o_bar.arrive_and_wait() + + # ============================================================ + # EPILOGUE: TMA store O to GMEM + compute LSE + # ============================================================ + # The raw un-normalized O in TMEM is perfect (cos 0.999998). + # We use epilogue_tma_store which reads O from TMEM directly via + # the correct get_tmem_load_op layout — no round-trip needed. + # + # For multi-KV-tile: the paired-atom O rescale above (kt>0) ensures + # O is correctly rescaled before this epilogue reads it. + # + # External normalization (D5a path): kernel outputs un-normalized O + + # LSE + row_sum. Caller normalizes using O_norm = O_unnorm / row_sum. + # This is exact and composes with D5c sink bias merge. + # ============================================================ + + # TMA store via CUTLASS epilogue_tma_store (reads raw O from TMEM) + tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) + c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) + c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp) + acc_cons_st = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + acc_cons_st = utils.gemm.sm100.epilogue_tma_store( + self, sfw_idx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, + 0, const_expr(lambda x: x), (0, 0, 0), + acc_cons_st, acc_pipe, c_pipe, + ) + c_pipe.producer_tail() + + # Compute LSE: lse = ln(row_sum) + row_max * ln(2) + # Only when emitting un-normalized output (D5a path). + # When normalize=True, LSE is not needed (in-kernel normalization). + # + # Per-row LSE: each softmax thread (sfw_idx 0..127) handles one row. + # sfw_idx maps directly to the row index in the attention matrix. + # All 128 threads write independently to mLSE[sfw_idx] — no sync needed. + if const_expr(not self.normalize): + _row_max_safe = row_max + if row_max == -cutlass.Float32.inf: + _row_max_safe = Float32(0.0) + _ln2 = Float32(0.6931471805599453) # ln(2) + lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2 + mLSE[sfw_idx, Int32(0), Int32(0)] = lse_val + # Also output row_sum for external normalization (D5c) + mRowSums[sfw_idx, Int32(0), Int32(0)] = row_sum + + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr)