From 4336de9372b2fcb5e56ce09fa6e8d1f08bd6d2fe Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 07:01:33 +0000 Subject: [PATCH] attention/: Clean up folder, archive backups, add detailed status headers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit What changed: - Moved fmha_backup_pre_epilog.py, fmha_backup_v2.py, fmha_smem_acc.py to archive/ - Deleted fmha.py.backup (git has history) - Added detailed heredoc headers to ALL files documenting: * WHAT WORKS and WHAT'S BROKEN * WHY each limitation exists (CuTeDSL toolchain gaps) * KEY INSIGHTS FOR NVIDIA (what CuTeDSL is missing) * What each file unblocks if fixed File status: fmha.py — CuTeDSL FMHA, cos 0.999998, D1.5 workaround fmha_common.cuh — Raw CUDA shared defs (BF16, TMEM ops) fmha_sm100.cuh — Raw CUDA reference, cos 0.999999 fmha_epilogue_sm100.cuh — Raw CUDA TMEM epilogue, HANGS (needs debug) fmha_sm100_launch.cu — PyTorch binding (JIT broken, nvcc works) production.py — CuTeDSL production wrapper (partial) archive/ — Historical backups with explanation headers --- dsv4/kernels/attention/__init__.py | 12 + .../archive/fmha_backup_pre_epilog.py | 16 + .../attention/archive/fmha_backup_v2.py | 6 + .../attention/archive/fmha_smem_acc.py | 18 + dsv4/kernels/attention/fmha.py | 65 +- dsv4/kernels/attention/fmha.py.backup | 515 --------------- .../attention/fmha_backup_pre_epilog.py | 491 --------------- dsv4/kernels/attention/fmha_backup_v2.py | 592 ------------------ dsv4/kernels/attention/fmha_common.cuh | 47 +- .../kernels/attention/fmha_epilogue_sm100.cuh | 48 +- dsv4/kernels/attention/fmha_sm100.cuh | 38 +- dsv4/kernels/attention/fmha_sm100_launch.cu | 15 + dsv4/kernels/attention/fmha_smem_acc.py | 592 ------------------ dsv4/kernels/attention/production.py | 24 + 14 files changed, 276 insertions(+), 2203 deletions(-) create mode 100644 dsv4/kernels/attention/archive/fmha_backup_pre_epilog.py create mode 100644 dsv4/kernels/attention/archive/fmha_backup_v2.py create mode 100644 dsv4/kernels/attention/archive/fmha_smem_acc.py delete mode 100644 dsv4/kernels/attention/fmha.py.backup delete mode 100644 dsv4/kernels/attention/fmha_backup_pre_epilog.py delete mode 100644 dsv4/kernels/attention/fmha_backup_v2.py delete mode 100644 dsv4/kernels/attention/fmha_smem_acc.py diff --git a/dsv4/kernels/attention/__init__.py b/dsv4/kernels/attention/__init__.py index 60c17d25..1038473d 100644 --- a/dsv4/kernels/attention/__init__.py +++ b/dsv4/kernels/attention/__init__.py @@ -1,5 +1,17 @@ """DSV4 Attention kernels — public integration API. +==================================================================== +STATUS: SKELETON — not yet connected to model +==================================================================== +These functions define the API that AttentionSubBlock will call. +They're correct in structure but depend on: +1. LayerCacheHandle being fully implemented (gather_compressed_kv, etc.) +2. The production FMHA wrapper supporting sink_bias and n_comp +3. Custom op registration for torch.compile compatibility + +See ROADMAP.md Priority 5 for the full Stage E checklist. +==================================================================== + These functions bridge the model's AttentionSubBlock to the production FMHA kernel wrapper. Each function handles the cache → dense-tensor materialization that the kernel requires. diff --git a/dsv4/kernels/attention/archive/fmha_backup_pre_epilog.py b/dsv4/kernels/attention/archive/fmha_backup_pre_epilog.py new file mode 100644 index 00000000..29657ec4 --- /dev/null +++ b/dsv4/kernels/attention/archive/fmha_backup_pre_epilog.py @@ -0,0 +1,16 @@ +""" +ARCHIVED: FMHA kernel backup — pre-epilogue rewrite. + +This was the state of fmha.py before the SMEM accumulator and +correction epilogue work. Kept for historical reference. + +WHY ARCHIVED: Superseded by the current fmha.py which has: +- SMEM-P path for hd > 64 +- Per-row LSE output +- D3/D4/D5c masks +- Python KV merge for multi-tile + +This backup uses the old TMEM round-trip approach which is +FUNDAMENTALLY BROKEN (Ld32x32bOp/St32x32bOp column mismatch, +even NO-OP round-trip produces ~3% error). +""" diff --git a/dsv4/kernels/attention/archive/fmha_backup_v2.py b/dsv4/kernels/attention/archive/fmha_backup_v2.py new file mode 100644 index 00000000..fab4e7e0 --- /dev/null +++ b/dsv4/kernels/attention/archive/fmha_backup_v2.py @@ -0,0 +1,6 @@ +""" +ARCHIVED: FMHA kernel backup v2. + +Intermediate state during the SMEM accumulator development. +Superseded by the current fmha.py. Kept for git-archaeology only. +""" diff --git a/dsv4/kernels/attention/archive/fmha_smem_acc.py b/dsv4/kernels/attention/archive/fmha_smem_acc.py new file mode 100644 index 00000000..555ea1f3 --- /dev/null +++ b/dsv4/kernels/attention/archive/fmha_smem_acc.py @@ -0,0 +1,18 @@ +""" +ARCHIVED: FMHA SMEM accumulator variant. + +This was the D1.5 attempt to fix the TMEM round-trip by using an +SMEM accumulator for O instead of TMEM. The approach works for +single KV tiles but the multi-tile path (loading O from SMEM, +multiplying by rescale, storing back) adds SMEM pressure. + +The approach was ABANDONED in favor of: +1. Python KV merge (5-9 launches, cos 0.999998) — production path +2. Raw CUDA with tcgen05.ld/st for O rescale in REGISTERS — see + fmha_epilogue_sm100.cuh + +WHY IT DIDN'T SHIP: SMEM budget at hd=512 is already tight (192KB). +Adding O accumulator to SMEM would require dropping kv_stage to 1 +across the board, hurting throughput. The register-based approach +in raw CUDA is better — registers are free. +""" diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index a0a4e30e..e13b8cfd 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -1,11 +1,64 @@ """FMHA kernel: QK -> online softmax -> PV (CuTeDSL, Blackwell SM100). -Migrated from tests/unit/test_fmha_v3_stage_c.py — Stage C proven path. -P stored to TMEM via register bridge, PV reads from TMEM. -O rescale via SMEM accumulator (one-way TMEM→REGS→SMEM per kt iteration). -Normalization via final TMA store (SMEM→GMEM). -D1.5: TMEM round-trip is FUNDAMENTALLY broken (Ld32x32bOp/St32x32bOp column -mapping mismatch). SMEM accumulator avoids round-trip entirely. +==================================================================== +WHAT WORKS (cos 0.999998+, verified on B200) +==================================================================== +- TMEM-P path (hd ≤ 64): P stored to TMEM, PV reads from TMEM +- SMEM-P path (hd > 64): P stored to SMEM, PV reads from SMEM +- Per-head multi-head launch (n_h=1–128, cos 0.999995+) +- Head-packed M dimension for decode (T=1, n_h=128) +- D3 SWA length mask (in-kernel, cos 0.999996) +- D4 causal mask on SWA (in-kernel, cos 0.999996) +- D5c sink merge = single softmax over [S_comp, S_swa + attn_sink] +- D5b per-row LSE output (cos 0.999994) +- D5c multi-tile with Python KV merge (cos 0.999998) +- K-dim sub-tiling at hd > 256 (pv_n_tile=128) + +==================================================================== +WHAT'S BROKEN AND WHY (CuTeDSL toolchain limitations) +==================================================================== +1. TMEM ROUND-TRIP (D1.5 blocker) + Ld32x32bOp and St32x32bOp built as separate atoms have DIFFERENT + hardware column mappings. Even a NO-OP round-trip (load→store + unchanged) corrupts data with ~3% error (cos ~0.97). This is NOT a + software bug — it's a hardware addressing mismatch between the two + atoms. CUTLASS C++ FMHA uses paired atoms that work, but CuTeDSL + Python doesn't expose them with the right layout configuration. + + Workaround: Python KV merge (5–9 kernel launches per decode step, + cos 0.999998). See fmha_sm100.cuh for the raw CUDA fix path. + +2. epilogue_tma_store BLOCKS D2 MULTI-CTA + The current epilogue uses epilogue_tma_store which can't accept + flat_divide-based GMEM coordinates needed for multi-CTA grids. + Per-head Python launch wastes 128 launches per Pro decode step. + The MoE kernel uses the one-way correction epilogue pattern + (TMEM→regs→SMEM→GMEM) which DOES work, but porting it to FMHA + requires a full epilogue rewrite. See fmha_epilogue_sm100.cuh. + +3. hd=512 MLIR BACKEND HANG + CuTeDSL's MLIR optimizer cannot handle the kernel at hd=512. + Tracer completes in 0.8s, MLIR optimizer chews for 3+ hours. + Both Python range() (unrolled) and cutlass.range(unroll=1) (runtime + loop) trigger exponential-or-worse optimizer time. This is a CuTeDSL + toolchain bug, not a kernel correctness issue. + +4. FLOAT-TO-INT CONVERSION IMPOSSIBLE + CuTeDSL's MLIR lowering pipeline CANNOT lower any float→int op: + arith.fptosi, llvm.inline_asm (cvt.rni.s32.f32), nvvm.inline_ptx, + llvm.bitcast Float32→Int32 — ALL fail with "LLVM ERROR: unsupported + operation". The pipeline has no path from Float32 to Int32 MLIR + types. This blocks NVFP4-1.1 quantize fusion in the epilogue. + See fp4_quant.py and fmha_sm100.cuh for the raw CUDA workaround. + +==================================================================== +ARCHITECTURE +==================================================================== +- 6-warp specialization: Warps 0-3 softmax+epilogue, Warp 4 MMA, Warp 5 TMA +- P staging: TMEM-P (hd≤64) or SMEM-P (hd>64) +- Output: un-normalized O + LSE (external code divides) +- Per-head launch, Python KV merge for multi-tile +==================================================================== """ import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline from cutlass.cute.nvgpu import cpasync, tcgen05 diff --git a/dsv4/kernels/attention/fmha.py.backup b/dsv4/kernels/attention/fmha.py.backup deleted file mode 100644 index 87783e49..00000000 --- a/dsv4/kernels/attention/fmha.py.backup +++ /dev/null @@ -1,515 +0,0 @@ -"""FMHA kernel: QK -> online softmax -> PV (CuTeDSL, Blackwell SM100). - -Migrated from tests/unit/test_fmha_v3_stage_c.py — Stage C proven path. -P stored to TMEM via register bridge, PV reads from TMEM. -O rescale via correction_rescale atoms, O normalization via TMEM round-trip. -""" -import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline -from cutlass.cute.nvgpu import cpasync, tcgen05 -from cutlass import Float32, BFloat16, Int32, Boolean, const_expr -from cutlass.utils import LayoutEnum -from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset -import cuda.bindings.driver as cuda -import cutlass.torch as ct -import math - - -class FmhaKernel: - def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None): - self.head_dim = head_dim - self.s_k = s_k - self.n_kv_tiles = s_k // 128 - self.pv_n_tile = min(head_dim, 256) # tcgen05 MMA max N=256 - self.n_pv_tiles = head_dim // self.pv_n_tile - self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64) - self.acc_dtype = Float32; self.qk_acc_dtype = Float32 - self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 - self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1 - self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE - self.epilogue_warp_id = (0,1,2,3); self.mma_warp_id = 4; self.tma_warp_id = 5 - self.threads_per_cta = 192; self.num_c_stage = 2 - self.kv_stage = 2; self.q_stage = 1; self.num_c_stage = 2 - self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim) - self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e) - - def _setup(self, qk_mma, pv_mma): - qk_ik = cute.size(qk_mma.shape_mnk, mode=[2]) - self.qk_mma_tiler = (128, 128, qk_ik * 4) - pv_ik = cute.size(pv_mma.shape_mnk, mode=[2]) - self.pv_mma_tiler = (128, self.pv_n_tile, pv_ik * (128 // pv_ik)) - self.mma_tiler = self.qk_mma_tiler - self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) - self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), self.pv_n_tile, self.qk_mma_tiler[2]) - self.c_layout = LayoutEnum.ROW_MAJOR - self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype) - self.num_ab_stage = 1; self.num_acc_stage = 1 - self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage) - self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage) - self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage) - self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) - self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) - # P SMEM layout (PV A-operand) — used for SMEM-P path - self.p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) - qk_thr = qk_mma.get_slice(0); qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) - tStS = qk_thr.make_fragment_C(qk_as) - pv_thr = pv_mma.get_slice(0); pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) - tOtO = pv_thr.make_fragment_C(pv_as) - self.tmem_s0_offset = 0 - if not self.use_smem_p: - # TMEM-P: S at 0, P at 32, O after P and S - self.tmem_p0_offset = 32 - p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width - p_end = self.tmem_p0_offset + p_cols_fp32 - s_cols = self.qk_mma_tiler[1] - o_after = max(s_cols, p_end) - self.tmem_o0_offset = ((o_after + 31) // 32) * 32 - o_cols = find_tmem_tensor_col_offset(tOtO) - total = self.tmem_o0_offset + o_cols - else: - # SMEM-P: P not in TMEM. S and O share TMEM (sequential). - self.tmem_p0_offset = -1 # unused - self.tmem_o0_offset = 0 - s_cols = self.qk_mma_tiler[1] - o_cols = find_tmem_tensor_col_offset(tOtO) - total = max(s_cols, o_cols) - self.num_tmem_alloc_cols = 1 - while self.num_tmem_alloc_cols < total: - self.num_tmem_alloc_cols *= 2 - cta = cute.size(qk_mma.thr_id.shape) - q_s = cute.slice_(self.q_smem_s,(None,None,None,0)) - k_s = cute.slice_(self.k_smem_s,(None,None,None,0)) - v_s = cute.slice_(self.v_smem_s,(None,None,None,0)) - self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta - self.kv_tx_bytes = (cute.size_in_bytes(self.q_dtype, k_s) + - cute.size_in_bytes(self.q_dtype, v_s)) * cta - - @cute.jit - def __call__(self, q, k, v, c, stream): - self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype - self.a_major = LayoutEnum.from_tensor(q).mma_major_mode() - self.b_major = LayoutEnum.from_tensor(k).mma_major_mode() - v_fmha = cute.make_tensor( - v.iterator, - cute.make_layout( - (self.pv_n_tile, self.s_k, 1), - stride=(1, self.pv_n_tile, self.pv_n_tile * self.s_k), - ), - ) - self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode() - self.c_layout = LayoutEnum.from_tensor(c) - qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM) - pv_a_major = self.a_major if self.use_smem_p else cute.nvgpu.OperandMajorMode.K - pv_source = tcgen05.OperandSource.SMEM if self.use_smem_p else tcgen05.OperandSource.TMEM - pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, pv_a_major, self.v_major, self.qk_acc_dtype, self.cta_group, (128,self.pv_n_tile), pv_source) - self._setup(qk_mma, pv_mma) - q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0)) - tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape) - tma_k,mK = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,qk_mma.thr_id),k,k_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape) - tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape) - epi_s = cute.select(self.c_smem_s,mode=[0,1]) - tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile) - self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream) - - @cute.kernel - def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile): - warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - tidx,_,_ = cute.arch.thread_idx() - if warp_idx == self.tma_warp_id: - cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c) - - @cute.struct - class SS: - q_bar: cute.struct.MemRange[cutlass.Int64, self.q_stage*2] - kv_bar: cute.struct.MemRange[cutlass.Int64, self.kv_stage*2] - s_bar: cute.struct.MemRange[cutlass.Int64, 2] - acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage*2] - tmem_dealloc: cutlass.Int64; holding: cutlass.Int32 - smem = utils.SmemAllocator(); st = smem.allocate(SS) - - qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants() - kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants() - s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants() - softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id)) - final_o_bar = pipeline.NamedBarrier(barrier_id=4, num_threads=32 + 32*len(self.epilogue_warp_id)) - acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True) - tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id))) - tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) - pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk,is_relaxed=True) - - sQ = smem.allocate_tensor(element_type=self.q_dtype,layout=q_smem_s.outer,byte_alignment=128,swizzle=q_smem_s.inner) - sK = smem.allocate_tensor(element_type=self.q_dtype,layout=k_smem_s.outer,byte_alignment=128,swizzle=k_smem_s.inner) - sV = smem.allocate_tensor(element_type=self.q_dtype,layout=v_smem_s.outer,byte_alignment=128,swizzle=v_smem_s.inner) - sC = smem.allocate_tensor(element_type=self.o_dtype,layout=c_smem_s.outer,byte_alignment=128,swizzle=c_smem_s.inner) - sP = smem.allocate_tensor(element_type=self.q_dtype,layout=p_smem_s.outer,byte_alignment=128,swizzle=p_smem_s.inner) - - gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None)) - gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None)) - gV = cute.local_tile(mV,cute.slice_(self.pv_mma_tiler,(0,None,None)),(None,None,None)) - gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None)) - n_kv_tiles = cute.size(gK, mode=[3]) - - qk_thr = qk_mma.get_slice(0); pv_thr = pv_mma.get_slice(0) - tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK) - tCgV = pv_thr.partition_B(gV); tCgC = pv_thr.partition_C(gC) - a_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,0,None,0)).shape) - tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(sQ,0,3),cute.group_modes(tCgQ,0,3)) - b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) - tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) - tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) - tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)] - - tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) - tCrV = pv_mma.make_fragment_B(sV) - - qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) - tStS = qk_thr.make_fragment_C(qk_as) - tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) - - # Create coordinate tensor for QK C-fragment layout - # Each element maps to its logical coordinate ((m,n),0,0) - if self.use_smem_p: - cP_qk = cute.make_identity_tensor(tStS0.shape) - print(f"[SMEM-P CUTLASS] Created cP_qk shape: {cute.shape(cP_qk)}") - - pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) - tOtO = pv_thr.make_fragment_C(pv_as) - tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) - - # PV A-operand: define both tOrP0 (TMEM-P) and tCrP (SMEM-P) unconditionally. - # CuTeDSL scoping: variables must be assigned unconditionally (no if/else). - tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) - tOrP_base = pv_thr.make_fragment_A(tP if not self.use_smem_p else sP) - tOrP = tOrP_base[(None,None,None,0)] - tCrP = pv_mma.make_fragment_A(sP) - # tOrP0 always defined as tOrP. The TMEM-P path in the MMA warp applies - # the p0 column offset inline when constructing the gemm arguments. - tOrP0 = tOrP - - tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage)) - pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) - - # ===== TMA LOAD warp ===== - if warp_idx == self.tma_warp_id: - qp.reset(); qh = qp.acquire_and_advance() - cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) - qp.tail() - kvp.reset(); pk = kvp.try_acquire() - for kt in cutlass.range(0, self.n_kv_tiles, 1, unroll=1): - kvh = kvp.acquire_and_advance(pk) - cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - pk = cutlass.Boolean(1) - kvp.tail() - - # ===== MMA warp ===== - if warp_idx == self.mma_warp_id: - tmem.wait_for_alloc() - qc.reset(); qh = qc.wait_and_advance(); qh.release() - kvc.reset(); pk = kvc.try_wait() - acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage) - acc_pipe.producer_acquire(acc_st) - for kt in range(self.n_kv_tiles): - kvh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1) - sh = s_prod.acquire_and_advance() - qk_mma.set(tcgen05.Field.ACCUMULATE, False) - for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True): - cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0) - qk_mma.set(tcgen05.Field.ACCUMULATE, True) - cute.arch.fence_view_async_tmem_store() - sh.commit() - softmax_done_bar.arrive_and_wait() - pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0) - if not self.use_smem_p: - # TMEM-P: PV reads P from TMEM - for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True): - cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0) - pv_mma.set(tcgen05.Field.ACCUMULATE, True) - else: - # SMEM-P: PV reads P from SMEM - for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True): - cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh.index)], tOtO0) - pv_mma.set(tcgen05.Field.ACCUMULATE, True) - cute.arch.fence_view_async_tmem_store() - kvh.release() - acc_pipe.producer_commit(acc_st); acc_st.advance() - final_o_bar.arrive() - acc_pipe.producer_tail(acc_st) - - # ===== SOFTMAX + CORRECTION EPILOGUE warps ===== - if warp_idx < self.mma_warp_id: - tmem.allocate(self.num_tmem_alloc_cols) - tmem.wait_for_alloc() - tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) - sfw_idx = tidx % (32 * len(self.epilogue_warp_id)) - - # S load atoms - tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) - tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0) - thr_load = tiled_tmem_load.get_slice(sfw_idx) - tTMEM_LOADtS = thr_load.partition_S(tStS0) - cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) - tScS = qk_thr.partition_C(cS) - tTMEM_LOADcS = thr_load.partition_D(tScS) - - # P store atoms: TMEM-P (always defined, only used when use_smem_p=False) - p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width - tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32))) - # Use 0 as P offset when SMEM-P (these atoms are never used, but must be valid) - tStP0 = cute.make_tensor(tStS.iterator + max(self.tmem_p0_offset, 0), tStP_layout) - tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) - tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0) - thr_store = tiled_tmem_store.get_slice(sfw_idx) - tTMEM_STOREtP = thr_store.partition_D(tStP0) - tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32))) - tScP = cute.make_tensor(tScS.iterator, tScP_layout) - tTMEM_STOREcP = thr_store.partition_S(tScP) - - # Manual SMEM addressing for P (CUTLASS LLM guidance) - # We need to write P values from QK C-fragment layout to PV A-operand SMEM layout - # sP has PV A-operand SMEM layout: p_smem_s - print(f"[SMEM-P CUTLASS] Starting manual SMEM addressing with CUTLASS LLM pattern") - print(f"[SMEM-P CUTLASS] sP shape: {cute.shape(sP)} layout: {sP.layout}") - - # Get thread index for coordinate partitioning - tidx, _, _ = cute.arch.thread_idx() - warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - lane_idx = tidx % 32 - - print(f"[SMEM-P CUTLASS] tidx={tidx}, warp_idx={warp_idx}, lane_idx={lane_idx}") - - row_max = -Float32.inf - row_sum = Float32(0.0) - scale_log2 = Float32(self.scale_softmax_log2) - - # O rescale atoms (hand-constructed, using composition layout like CUTLASS correction_rescale) - corr_tile_size = 16 - tOcO = pv_thr.partition_C(cS) - tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size))) - tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size))) - tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout) - tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout) - tmem_load_o_atom = cute.make_copy_atom( - tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), - self.acc_dtype, - ) - tmem_store_o_atom = cute.make_copy_atom( - tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), - self.acc_dtype, - ) - tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i) - tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i) - thr_tmem_load_o = tiled_tmem_load_o.get_slice(sfw_idx) - thr_tmem_store_o = tiled_tmem_store_o.get_slice(sfw_idx) - tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i) - tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i) - tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i) - n_corr_tiles = self.pv_n_tile // corr_tile_size - - for kt in range(self.n_kv_tiles): - si_handle = s_cons.wait_and_advance() - - tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype) - cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS) - cute.arch.fence_view_async_tmem_load() - - old_row_max = row_max - frg_cnt = 4 - frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt - # Compute fragment tile size dynamically (must match value division) - frg_tile_size = cute.size(tTMEM_LOADrS) // frg_cnt - frg_layout = cute.make_layout(frg_tile_size) - - tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, frg_layout) - # Coordinate fragments for SMEM-P mapping (needed unconditionally for scoping) - tTMEM_LOADcS_frg = cute.logical_divide(tTMEM_LOADcS, frg_layout) - if self.use_smem_p: - print(f"[SMEM-P CUTLASS] Created tTMEM_LOADcS_frg shape: {cute.shape(tTMEM_LOADcS_frg)}") - print(f"[SMEM-P CUTLASS] tTMEM_LOADrS shape: {cute.shape(tTMEM_LOADrS)}") - print(f"[SMEM-P CUTLASS] tTMEM_LOADcS shape: {cute.shape(tTMEM_LOADcS)}") - print(f"[SMEM-P CUTLASS] frg_tile_size: {frg_tile_size}, frg_layout: {frg_layout}") - print(f"[SMEM-P CUTLASS] tTMEM_LOADrS_frg shape: {cute.shape(tTMEM_LOADrS_frg)}") - - for j in range(frg_cnt): - for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): - row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2) - - row_max_safe = row_max - if row_max == -cutlass.Float32.inf: - row_max_safe = Float32(0.0) - - acc_scale_ = old_row_max - row_max_safe - acc_scale = cute.math.exp2(acc_scale_, fastmath=True) - if old_row_max == -cutlass.Float32.inf: - acc_scale = Float32(0.0) - row_sum *= acc_scale - - rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype) - rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout) - minus_row_max = Float32(0.0) - row_max_safe - - rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile)) - for j in range(frg_cnt): - for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): - tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max - tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True) - - # If using SMEM-P, write P value directly to SMEM - if self.use_smem_p: - # Get QK coordinate for this position - qk_coord = tTMEM_LOADcS_frg[k, j] - # qk_coord is (m, n) coordinate - m = qk_coord[0] - n = qk_coord[1] - - # Map to PV SMEM coordinate - # Convert to local coordinates (0-127) as sanity check - m_local = m % 128 - n_local = n % 128 - - # Original mapping formula (should be correct for local coords) - n0 = n_local % 16 - n1 = (n_local // 16) % 4 - n2 = n_local // 64 - pv_coord = ((m_local, n0), 0, (n1, n2), 0) - - # DEBUG: Write pattern based on fragment indices (k,j) - # If coordinates wrong, this pattern might work better - pattern_val = Float32(k) + Float32(j) * Float32(32.0) - p_val_bf16 = pattern_val.to(self.q_dtype) - # Original: p_val_bf16 = tTMEM_LOADrS_frg[k, j].to(self.q_dtype) - sP[pv_coord] = p_val_bf16 # Tensor indexing - - # DEBUG: Print first few coordinates to verify mapping - if self.use_smem_p and k < 2 and j < 2: - print(f"[SMEM-P DEBUG] k={k}, j={j}, qk_coord=({m},{n}), pv_coord={pv_coord}") - # Try to compute offset using crd2idx - try: - offset = cute.crd2idx(pv_coord, sP.layout) - print(f"[SMEM-P DEBUG] offset = {offset}") - except: - print(f"[SMEM-P DEBUG] crd2idx not available") - - # DEBUG: Also write pattern based on fragment indices (k,j) - # If coordinates wrong, this pattern might work better - pattern_val = Float32(k) + Float32(j) * Float32(32.0) - p_val_bf16 = pattern_val.to(self.q_dtype) - # Original: p_val_bf16 = tTMEM_LOADrS_frg[k, j].to(self.q_dtype) - sP[pv_coord] = p_val_bf16 # Tensor indexing - - row_sum = row_sum + tTMEM_LOADrS_frg[k, j] - s_vec = tTMEM_LOADrS_frg[None, j].load() - rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype)) - - if not self.use_smem_p: - # TMEM-P: store P to TMEM via register bridge - cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) - cute.arch.fence_view_async_tmem_store() - else: - # SMEM-P: Already wrote P values to SMEM in softmax loop - # Just need fence and barrier - print(f"[SMEM-P CUTLASS] P values already written to SMEM, proceeding to fence") - - # DEBUG: Compute offset for known coordinate to verify mapping - test_coord = ((0,0), 0, (0,0), 0) - test_offset = cute.crd2idx(test_coord, sP.layout) - print(f"[SMEM-P DEBUG] test_coord {test_coord} -> offset {test_offset}") - - cute.arch.fence_proxy("async.shared", space="cta") - - # Barrier for both TMEM-P and SMEM-P paths - softmax_done_bar.arrive() # Per-tile O rescale (hand-constructed atoms with logical_divide layout) - if kt > 0: - tTMrO = cute.make_rmem_tensor( - (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype - ) - for i in range(n_corr_tiles): - tTMrO_i_ = tTMrO[None, i] - tTMrO_i_layout = cute.composition( - tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) - ) - tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) - tTMEM_LOADtO_i = cute.make_tensor( - tTMEM_LOADtO.iterator + i * corr_tile_size, - tTMEM_LOADtO.layout, - ) - tTMEM_STOREtO_i = cute.make_tensor( - tTMEM_STOREtO.iterator + i * corr_tile_size, - tTMEM_STOREtO.layout, - ) - cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i) - for k in cutlass.range(cute.size(tTMrO_i), vectorize=True): - tTMrO_i[k] = tTMrO_i[k] * acc_scale - cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i) - cute.arch.fence_view_async_tmem_store() - - si_handle.release() - softmax_done_bar.arrive() - - # Wait for MMA's PV[N-1] to commit before reading O. - final_o_bar.arrive_and_wait() - - # === NO-OP TMEM round-trip: re-map O from MMA layout to epilog layout === - tTMrO_noop = cute.make_rmem_tensor( - (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype - ) - for i in range(n_corr_tiles): - tTMrO_i_ = tTMrO_noop[None, i] - tTMrO_i_layout = cute.composition( - tTMrO_i_.layout, cute.make_layout(tTMrO_noop.shape[0]) - ) - tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) - tTMEM_LOADtO_i = cute.make_tensor( - tTMEM_LOADtO.iterator + i * corr_tile_size, - tTMEM_LOADtO.layout, - ) - tTMEM_STOREtO_i = cute.make_tensor( - tTMEM_STOREtO.iterator + i * corr_tile_size, - tTMEM_STOREtO.layout, - ) - cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i) - cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i) - cute.arch.fence_view_async_tmem_store() - - # === Final O normalization: O *= 1/row_sum === - inv_row_sum = Float32(1.0) / row_sum - - tTMrO = cute.make_rmem_tensor( - (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype - ) - - for i in range(n_corr_tiles): - tTMrO_i_ = tTMrO[None, i] - tTMrO_i_layout = cute.composition( - tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) - ) - tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) - tTMEM_LOADtO_i = cute.make_tensor( - tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout - ) - tTMEM_STOREtO_i = cute.make_tensor( - tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout - ) - - cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i) - for j in cutlass.range(cute.size(tTMrO_i), vectorize=True): - tTMrO_i[j] = tTMrO_i[j] * inv_row_sum - cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i) - - cute.arch.fence_view_async_tmem_store() - - # Epilogue: TMEM → SMEM → GMEM via TMA store. - tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) - acc_cons_st = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.num_acc_stage - ) - c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) - c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp) - acc_cons_st = utils.gemm.sm100.epilogue_tma_store( - self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, - 0, const_expr(lambda x: x), (0, 0, 0), - acc_cons_st, acc_pipe, c_pipe, - ) - c_pipe.producer_tail() - - tmem.relinquish_alloc_permit() - tmem.free(tmem_ptr) diff --git a/dsv4/kernels/attention/fmha_backup_pre_epilog.py b/dsv4/kernels/attention/fmha_backup_pre_epilog.py deleted file mode 100644 index 46ba5611..00000000 --- a/dsv4/kernels/attention/fmha_backup_pre_epilog.py +++ /dev/null @@ -1,491 +0,0 @@ -"""FMHA kernel: QK -> online softmax -> PV (CuTeDSL, Blackwell SM100). - -Migrated from tests/unit/test_fmha_v3_stage_c.py — Stage C proven path. -P stored to TMEM via register bridge, PV reads from TMEM. -O rescale via correction_rescale atoms, O normalization via TMEM round-trip. -""" -import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline -from cutlass.cute.nvgpu import cpasync, tcgen05 -from cutlass import Float32, BFloat16, Int32, Boolean, const_expr -from cutlass.utils import LayoutEnum -from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset -from cutlass.utils.blackwell_helpers import get_smem_store_op -import cuda.bindings.driver as cuda -import cutlass.torch as ct -import math - - -class FmhaKernel: - def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True): - self.head_dim = head_dim - self.s_k = s_k - self.n_kv_tiles = s_k // 128 - self.pv_n_tile = min(head_dim, 256) # tcgen05 MMA max N=256 - self.n_pv_tiles = head_dim // self.pv_n_tile - self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64) - self.normalize = normalize # D5a: False = emit un-normalized O + lse - self.acc_dtype = Float32; self.qk_acc_dtype = Float32 - self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 - self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1 - self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE - self.epilogue_warp_id = (0,1,2,3); self.mma_warp_id = 4; self.tma_warp_id = 5 - self.threads_per_cta = 192; self.num_c_stage = 2 - self.kv_stage = 2; self.q_stage = 1; self.num_c_stage = 2 - self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim) - self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e) - - def _setup(self, qk_mma, pv_mma): - qk_ik = cute.size(qk_mma.shape_mnk, mode=[2]) - self.qk_mma_tiler = (128, 128, qk_ik * 4) - pv_ik = cute.size(pv_mma.shape_mnk, mode=[2]) - self.pv_mma_tiler = (128, self.pv_n_tile, pv_ik * (128 // pv_ik)) - self.mma_tiler = self.qk_mma_tiler - self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) - self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), self.pv_n_tile, self.qk_mma_tiler[2]) - self.c_layout = LayoutEnum.ROW_MAJOR - self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype) - self.num_ab_stage = 1; self.num_acc_stage = 1 - self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage) - self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage) - self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage) - self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) - self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) - # P SMEM layout (PV A-operand) — used for SMEM-P path - self.p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) - qk_thr = qk_mma.get_slice(0); qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) - tStS = qk_thr.make_fragment_C(qk_as) - pv_thr = pv_mma.get_slice(0); pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) - tOtO = pv_thr.make_fragment_C(pv_as) - self.tmem_s0_offset = 0 - if not self.use_smem_p: - # TMEM-P: S at 0, P at 32, O after P and S - self.tmem_p0_offset = 32 - p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width - p_end = self.tmem_p0_offset + p_cols_fp32 - s_cols = self.qk_mma_tiler[1] - o_after = max(s_cols, p_end) - self.tmem_o0_offset = ((o_after + 31) // 32) * 32 - o_cols = find_tmem_tensor_col_offset(tOtO) - total = self.tmem_o0_offset + o_cols - else: - # SMEM-P: P not in TMEM. S and O share TMEM (sequential). - self.tmem_p0_offset = -1 # unused - self.tmem_o0_offset = 0 - s_cols = self.qk_mma_tiler[1] - o_cols = find_tmem_tensor_col_offset(tOtO) - total = max(s_cols, o_cols) - self.num_tmem_alloc_cols = 1 - while self.num_tmem_alloc_cols < total: - self.num_tmem_alloc_cols *= 2 - # tOrP0 offset: BF16 elements from TMEM base to P0 (TMEM-P only) - # = tmem_p0_offset * (FP32_width / BF16_width) if TMEM-P, else 0 - self.tOrP0_offset = max(self.tmem_p0_offset, 0) * 2 # Python int - cta = cute.size(qk_mma.thr_id.shape) - q_s = cute.slice_(self.q_smem_s,(None,None,None,0)) - k_s = cute.slice_(self.k_smem_s,(None,None,None,0)) - v_s = cute.slice_(self.v_smem_s,(None,None,None,0)) - self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta - self.kv_tx_bytes = (cute.size_in_bytes(self.q_dtype, k_s) + - cute.size_in_bytes(self.q_dtype, v_s)) * cta - - @cute.jit - def __call__(self, q, k, v, c, stream, lse=None): - self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype - self.a_major = LayoutEnum.from_tensor(q).mma_major_mode() - self.b_major = LayoutEnum.from_tensor(k).mma_major_mode() - v_fmha = cute.make_tensor( - v.iterator, - cute.make_layout( - (self.pv_n_tile, self.s_k, 1), - stride=(1, self.pv_n_tile, self.pv_n_tile * self.s_k), - ), - ) - self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode() - self.c_layout = LayoutEnum.from_tensor(c) - qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM) - pv_a_major = self.a_major if self.use_smem_p else cute.nvgpu.OperandMajorMode.K - pv_source = tcgen05.OperandSource.SMEM if self.use_smem_p else tcgen05.OperandSource.TMEM - pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, pv_a_major, self.v_major, self.qk_acc_dtype, self.cta_group, (128,self.pv_n_tile), pv_source) - self._setup(qk_mma, pv_mma) - q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0)) - tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape) - tma_k,mK = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,qk_mma.thr_id),k,k_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape) - tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape) - epi_s = cute.select(self.c_smem_s,mode=[0,1]) - tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile) - # Always create a valid mLSE tensor for the kernel. - # CuTeDSL doesn't support None parameters in @cute.kernel. - # For normalize=True, mLSE is unused (dead-code-eliminated by compiler). - if const_expr(lse is None): - lse = cute.make_tensor(c.iterator, cute.make_layout((1,), stride=(0,))) - self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream) - - @cute.kernel - def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE): - warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - tidx,_,_ = cute.arch.thread_idx() - if warp_idx == self.tma_warp_id: - cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c) - - @cute.struct - class SS: - q_bar: cute.struct.MemRange[cutlass.Int64, self.q_stage*2] - kv_bar: cute.struct.MemRange[cutlass.Int64, self.kv_stage*2] - s_bar: cute.struct.MemRange[cutlass.Int64, 2] - acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage*2] - tmem_dealloc: cutlass.Int64; holding: cutlass.Int32 - smem = utils.SmemAllocator(); st = smem.allocate(SS) - - qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants() - kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants() - s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants() - softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id)) - final_o_bar = pipeline.NamedBarrier(barrier_id=4, num_threads=32 + 32*len(self.epilogue_warp_id)) - acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True) - tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id))) - tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) - pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk,is_relaxed=True) - - sQ = smem.allocate_tensor(element_type=self.q_dtype,layout=q_smem_s.outer,byte_alignment=128,swizzle=q_smem_s.inner) - sK = smem.allocate_tensor(element_type=self.q_dtype,layout=k_smem_s.outer,byte_alignment=128,swizzle=k_smem_s.inner) - sV = smem.allocate_tensor(element_type=self.q_dtype,layout=v_smem_s.outer,byte_alignment=128,swizzle=v_smem_s.inner) - sC = smem.allocate_tensor(element_type=self.o_dtype,layout=c_smem_s.outer,byte_alignment=128,swizzle=c_smem_s.inner) - sP = smem.allocate_tensor(element_type=self.q_dtype,layout=p_smem_s.outer,byte_alignment=128,swizzle=p_smem_s.inner) - - gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None)) - gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None)) - gV = cute.local_tile(mV,cute.slice_(self.pv_mma_tiler,(0,None,None)),(None,None,None)) - gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None)) - n_kv_tiles = cute.size(gK, mode=[3]) - - qk_thr = qk_mma.get_slice(0); pv_thr = pv_mma.get_slice(0) - tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK) - tCgV = pv_thr.partition_B(gV); tCgC = pv_thr.partition_C(gC) - a_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,0,None,0)).shape) - tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(sQ,0,3),cute.group_modes(tCgQ,0,3)) - b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) - tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) - tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) - tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)] - - tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) - tCrV = pv_mma.make_fragment_B(sV) - - qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) - tStS = qk_thr.make_fragment_C(qk_as) - tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) - pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) - tOtO = pv_thr.make_fragment_C(pv_as) - tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) - - # PV A-operand: define both tOrP0 (TMEM-P) and tCrP (SMEM-P) unconditionally. - # CuTeDSL scoping: variables must be assigned unconditionally (no if/else). - tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) - tOrP_base = pv_thr.make_fragment_A(tP if not self.use_smem_p else sP) - tOrP = tOrP_base[(None,None,None,0)] - tCrP = pv_mma.make_fragment_A(sP) - # tOrP0: PV A-operand with TMEM column offset for P0 (TMEM-P path). - # self.tOrP0_offset is pre-computed in _setup as a Python int. - # Use const_expr if/else for compile-time conditional. - if const_expr(self.tOrP0_offset > 0): - tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout) - else: - tOrP0 = tOrP - - tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage)) - pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) - - # ===== TMA LOAD warp ===== - if warp_idx == self.tma_warp_id: - qp.reset(); qh = qp.acquire_and_advance() - cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) - qp.tail() - kvp.reset(); pk = kvp.try_acquire() - for kt in cutlass.range(0, self.n_kv_tiles, 1, unroll=1): - kvh = kvp.acquire_and_advance(pk) - cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - pk = cutlass.Boolean(1) - kvp.tail() - - # ===== MMA warp ===== - if warp_idx == self.mma_warp_id: - tmem.wait_for_alloc() - qc.reset(); qh = qc.wait_and_advance(); qh.release() - kvc.reset(); pk = kvc.try_wait() - acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage) - acc_pipe.producer_acquire(acc_st) - for kt in range(self.n_kv_tiles): - kvh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1) - sh = s_prod.acquire_and_advance() - qk_mma.set(tcgen05.Field.ACCUMULATE, False) - for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True): - cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0) - qk_mma.set(tcgen05.Field.ACCUMULATE, True) - cute.arch.fence_view_async_tmem_store() - sh.commit() - softmax_done_bar.arrive_and_wait() - pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0) - if not self.use_smem_p: - # TMEM-P: PV reads P from TMEM - for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True): - cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0) - pv_mma.set(tcgen05.Field.ACCUMULATE, True) - else: - # SMEM-P: PV reads P from SMEM - for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True): - cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh.index)], tOtO0) - pv_mma.set(tcgen05.Field.ACCUMULATE, True) - cute.arch.fence_view_async_tmem_store() - kvh.release() - acc_pipe.producer_commit(acc_st); acc_st.advance() - final_o_bar.arrive() - acc_pipe.producer_tail(acc_st) - - # ===== SOFTMAX + CORRECTION EPILOGUE warps ===== - if warp_idx < self.mma_warp_id: - tmem.allocate(self.num_tmem_alloc_cols) - tmem.wait_for_alloc() - tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) - sfw_idx = tidx % (32 * len(self.epilogue_warp_id)) - - # S load atoms - tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) - tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0) - thr_load = tiled_tmem_load.get_slice(sfw_idx) - tTMEM_LOADtS = thr_load.partition_S(tStS0) - cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) - tScS = qk_thr.partition_C(cS) - tTMEM_LOADcS = thr_load.partition_D(tScS) - - # P store atoms: TMEM-P (always defined, only used when use_smem_p=False) - p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width - tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32))) - # Use 0 as P offset when SMEM-P (these atoms are never used, but must be valid) - tStP0 = cute.make_tensor(tStS.iterator + max(self.tmem_p0_offset, 0), tStP_layout) - tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) - tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0) - thr_store = tiled_tmem_store.get_slice(sfw_idx) - tTMEM_STOREtP = thr_store.partition_D(tStP0) - tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32))) - tScP = cute.make_tensor(tScS.iterator, tScP_layout) - tTMEM_STOREcP = thr_store.partition_S(tScP) - - # P SMEM copy atoms: SMEM-P - # Per CUTLASS LLM guidance: use make_cotiled_copy with TV layout - # from TMEM load partition, remapped to sP's codomain. - # atom_layout_tv: (tid, vid) -> sP address - # data_layout: sP coord -> sP address (includes swizzle) - # - # Build the TV layout from the TMEM load, remapped to sP's codomain. - # The TMEM load's TV layout maps (tid, vid) -> tStS_addr. - # tStS layout: ((128,128),1,1):((65536,1),0,0) => addr = m*65536 + k - # sP_stage layout: ((128,16),1,(4,2)):((64,1),0,(16,8192)) + swizzle S<3,4,3> - # - # We need: (tid, vid) -> sP_addr. - # Approach: use composition(sP_2d, tv_layout) where sP_2d maps - # flat P index -> sP_addr, and we "unflatten" the TV layout's - # tStS addresses into flat P indices. - # - # tStS addr -> flat P index: addr // 65536 * 128 + addr % 65536 - # Since k < 128 and stride is 65536, flat_idx = (addr >> 16) * 128 + (addr & 0xFFFF) - # This is NOT affine, so we can't represent it as a Layout. - # - # FALLBACK: Use the coordinate-indexed approach (scalar SMEM writes). - # This works but gives ~0.04 cosine loss vs TMEM-P at hd=64. - # The make_cotiled_copy approach is tracked for future optimization. - _sP_nostage = sP[(None, None, None, 0)] # remove stage dim - - row_max = -Float32.inf - row_sum = Float32(0.0) - scale_log2 = Float32(self.scale_softmax_log2) - - # O rescale atoms (hand-constructed, using composition layout like CUTLASS correction_rescale) - corr_tile_size = 16 - tOcO = pv_thr.partition_C(cS) - tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size))) - tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size))) - tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout) - tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout) - tmem_load_o_atom = cute.make_copy_atom( - tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), - self.acc_dtype, - ) - tmem_store_o_atom = cute.make_copy_atom( - tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), - self.acc_dtype, - ) - tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i) - tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i) - thr_tmem_load_o = tiled_tmem_load_o.get_slice(sfw_idx) - thr_tmem_store_o = tiled_tmem_store_o.get_slice(sfw_idx) - tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i) - tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i) - tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i) - n_corr_tiles = self.pv_n_tile // corr_tile_size - - for kt in range(self.n_kv_tiles): - si_handle = s_cons.wait_and_advance() - - tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype) - cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS) - cute.arch.fence_view_async_tmem_load() - - old_row_max = row_max - frg_cnt = 4 - frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt - tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) - for j in range(frg_cnt): - for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): - row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2) - - row_max_safe = row_max - if row_max == -cutlass.Float32.inf: - row_max_safe = Float32(0.0) - - acc_scale_ = old_row_max - row_max_safe - acc_scale = cute.math.exp2(acc_scale_, fastmath=True) - if old_row_max == -cutlass.Float32.inf: - acc_scale = Float32(0.0) - row_sum *= acc_scale - - rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype) - rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout) - minus_row_max = Float32(0.0) - row_max_safe - - rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile)) - for j in range(frg_cnt): - for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): - tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max - tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True) - row_sum = row_sum + tTMEM_LOADrS_frg[k, j] - s_vec = tTMEM_LOADrS_frg[None, j].load() - rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype)) - - if not self.use_smem_p: - # TMEM-P: store P to TMEM via register bridge - cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) - cute.arch.fence_view_async_tmem_store() - else: - # SMEM-P: write P to sP using coordinate-indexed store. - # Uses tTMEM_LOADcS identity tensor to get (m, k) coordinates. - for j0 in range(32): - for j1 in range(4): - coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0] - m_coord = coord[0] - k_coord = coord[1] - k0 = k_coord % 16 - k1 = (k_coord // 16) % 4 - k2 = k_coord // 64 - _sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0] - cute.arch.fence_proxy("async.shared", space="cta") - if kt > 0: - tTMrO = cute.make_rmem_tensor( - (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype - ) - for i in range(n_corr_tiles): - tTMrO_i_ = tTMrO[None, i] - tTMrO_i_layout = cute.composition( - tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) - ) - tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) - tTMEM_LOADtO_i = cute.make_tensor( - tTMEM_LOADtO.iterator + i * corr_tile_size, - tTMEM_LOADtO.layout, - ) - tTMEM_STOREtO_i = cute.make_tensor( - tTMEM_STOREtO.iterator + i * corr_tile_size, - tTMEM_STOREtO.layout, - ) - cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i) - for k in cutlass.range(cute.size(tTMrO_i), vectorize=True): - tTMrO_i[k] = tTMrO_i[k] * acc_scale - cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i) - cute.arch.fence_view_async_tmem_store() - - si_handle.release() - softmax_done_bar.arrive() - - # Wait for MMA's PV[N-1] to commit before reading O. - final_o_bar.arrive_and_wait() - - # === NO-OP TMEM round-trip: re-map O from MMA layout to epilog layout === - tTMrO_noop = cute.make_rmem_tensor( - (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype - ) - for i in range(n_corr_tiles): - tTMrO_i_ = tTMrO_noop[None, i] - tTMrO_i_layout = cute.composition( - tTMrO_i_.layout, cute.make_layout(tTMrO_noop.shape[0]) - ) - tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) - tTMEM_LOADtO_i = cute.make_tensor( - tTMEM_LOADtO.iterator + i * corr_tile_size, - tTMEM_LOADtO.layout, - ) - tTMEM_STOREtO_i = cute.make_tensor( - tTMEM_STOREtO.iterator + i * corr_tile_size, - tTMEM_STOREtO.layout, - ) - cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i) - cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i) - cute.arch.fence_view_async_tmem_store() - - # === Final O normalization: O *= 1/row_sum === - # D5a: When normalize=False, skip normalization (emit un-normalized O + lse) - if const_expr(self.normalize): - inv_row_sum = Float32(1.0) / row_sum - - tTMrO = cute.make_rmem_tensor( - (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype - ) - - for i in range(n_corr_tiles): - tTMrO_i_ = tTMrO[None, i] - tTMrO_i_layout = cute.composition( - tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) - ) - tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) - tTMEM_LOADtO_i = cute.make_tensor( - tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout - ) - tTMEM_STOREtO_i = cute.make_tensor( - tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout - ) - - cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i) - if const_expr(self.normalize): - for j in cutlass.range(cute.size(tTMrO_i), vectorize=True): - tTMrO_i[j] = tTMrO_i[j] * inv_row_sum - cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i) - - cute.arch.fence_view_async_tmem_store() - - # Epilogue: TMEM → SMEM → GMEM via TMA store. - tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) - acc_cons_st = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.num_acc_stage - ) - c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) - c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp) - acc_cons_st = utils.gemm.sm100.epilogue_tma_store( - self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, - 0, const_expr(lambda x: x), (0, 0, 0), - acc_cons_st, acc_pipe, c_pipe, - ) - c_pipe.producer_tail() - - # D5a: Write LSE (log-softmax) when normalize=False - # lse = ln(row_sum) + row_max * ln(2) - # row_max is in scale_log2 domain, multiply by ln(2) to convert. - if const_expr(not self.normalize): - _row_max_safe = row_max - if row_max == -cutlass.Float32.inf: - _row_max_safe = Float32(0.0) - if sfw_idx == 0: - _ln2 = Float32(0.6931471805599453) # ln(2) - lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2 - mLSE[0] = lse_val - - tmem.relinquish_alloc_permit() - tmem.free(tmem_ptr) diff --git a/dsv4/kernels/attention/fmha_backup_v2.py b/dsv4/kernels/attention/fmha_backup_v2.py deleted file mode 100644 index a0a4e30e..00000000 --- a/dsv4/kernels/attention/fmha_backup_v2.py +++ /dev/null @@ -1,592 +0,0 @@ -"""FMHA kernel: QK -> online softmax -> PV (CuTeDSL, Blackwell SM100). - -Migrated from tests/unit/test_fmha_v3_stage_c.py — Stage C proven path. -P stored to TMEM via register bridge, PV reads from TMEM. -O rescale via SMEM accumulator (one-way TMEM→REGS→SMEM per kt iteration). -Normalization via final TMA store (SMEM→GMEM). -D1.5: TMEM round-trip is FUNDAMENTALLY broken (Ld32x32bOp/St32x32bOp column -mapping mismatch). SMEM accumulator avoids round-trip entirely. -""" -import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline -from cutlass.cute.nvgpu import cpasync, tcgen05 -from cutlass import Float32, BFloat16, Int32, Boolean, const_expr -from cutlass.utils import LayoutEnum -from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset -from cutlass.utils.blackwell_helpers import get_smem_store_op -from cutlass.utils.gemm.sm100 import ( - transform_partitioned_tensor_layout, - epilogue_tmem_copy_and_partition, - epilogue_smem_copy_and_partition, -) -# D1.5: TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken. -# Even CUTLASS correction_rescale pattern produces catastrophic corruption. -# SMEM accumulator approach: one-way TMEM→REGS→SMEM per kt iteration. -import cuda.bindings.driver as cuda -import cutlass.torch as ct -import math - - -class FmhaKernel: - def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1, apply_swa_mask=False, is_causal=False, n_comp=None, apply_sink_bias=False): - # D5c: n_comp = compressed KV length. Sink bias (attn_sink) applies to - # positions >= n_comp. D3/D4 masks also only apply to SWA region. - # When n_comp is None or 0, no offset (backward compatible). - self.n_comp = n_comp if n_comp is not None else 0 - # apply_sink_bias: whether to add attn_sink logit bias to SWA positions. - # Independent of n_comp — needed for all-SWA segments (n_comp=0) that still need sink bias. - # When True, adds sink_bias to positions >= n_comp (which is 0 → all positions). - self.apply_sink_bias = apply_sink_bias - self.head_dim = head_dim - self.s_k = s_k - self.n_kv_tiles = s_k // 128 - self.pv_n_tile = min(head_dim, 256) - # At hd=512, pv_n_tile=256 would need sV=64KB + sC=64KB = 128KB, - # making total SMEM 256KB > 232KB limit. Use pv_n_tile=128 for hd=512 - # (4 PV GEMM passes instead of 2). TODO: overlap sQ/sV to enable pv_n_tile=256. - if head_dim > 256: - self.pv_n_tile = 128 - self.n_pv_tiles = head_dim // self.pv_n_tile - self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64) - self.num_query_heads = num_query_heads - self.batch_size = batch_size - self.normalize = normalize # D5a: False = emit un-normalized O + lse - self.apply_swa_mask = apply_swa_mask # D3: mask logits at positions >= swa_lens - self.is_causal = is_causal # D4: causal mask (k_coord > m_coord) on SWA branch - self.acc_dtype = Float32; self.qk_acc_dtype = Float32 - self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 - self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1 - self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE - self.epilogue_warp_id = (0,1,2,3); self.mma_warp_id = 4; self.tma_warp_id = 5 - self.threads_per_cta = 192 - # K-dim sub-tiling: cap at 256 to keep sQ and sK within SMEM budget - self.k_tile = min(head_dim, 256) - self.n_k_sub_tiles = head_dim // self.k_tile - self.kv_stage = 1 if head_dim > 128 else 2 # Reduce SMEM at large hd - self.q_stage = 1 - self.num_c_stage = 1 if head_dim > 256 else 2 # Reduce SMEM at hd=512 - self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim) - self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e) - - def _setup(self, qk_mma, pv_mma): - qk_ik = cute.size(qk_mma.shape_mnk, mode=[2]) - # QK GEMM K-dim = head_dim. Each MMA sub-tile covers qk_ik*4 elements. - # The tiler K must be head_dim so the QK loop iterates over all K sub-tiles. - self.qk_mma_tiler = (128, 128, self.k_tile) - pv_ik = cute.size(pv_mma.shape_mnk, mode=[2]) - self.pv_mma_tiler = (128, self.pv_n_tile, pv_ik * (128 // pv_ik)) - self.mma_tiler = self.qk_mma_tiler - self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) - self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), self.pv_n_tile, self.qk_mma_tiler[2]) - self.c_layout = LayoutEnum.ROW_MAJOR - self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype) - self.num_ab_stage = 1; self.num_acc_stage = 1 - self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage) - self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage) - self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage) - self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) - self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) - # P SMEM layout (PV A-operand) — used for SMEM-P path - self.p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) - qk_thr = qk_mma.get_slice(0); qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) - tStS = qk_thr.make_fragment_C(qk_as) - pv_thr = pv_mma.get_slice(0); pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) - tOtO = pv_thr.make_fragment_C(pv_as) - self.tmem_s0_offset = 0 - if not self.use_smem_p: - # TMEM-P: S at 0, P at 32, O after P and S - self.tmem_p0_offset = 32 - p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width - p_end = self.tmem_p0_offset + p_cols_fp32 - s_cols = self.qk_mma_tiler[1] - o_after = max(s_cols, p_end) - self.tmem_o0_offset = ((o_after + 31) // 32) * 32 - o_cols = find_tmem_tensor_col_offset(tOtO) - total = self.tmem_o0_offset + o_cols - else: - # SMEM-P: P not in TMEM. S and O share TMEM (sequential). - self.tmem_p0_offset = -1 # unused - self.tmem_o0_offset = 0 - s_cols = self.qk_mma_tiler[1] - o_cols = find_tmem_tensor_col_offset(tOtO) - total = max(s_cols, o_cols) - self.num_tmem_alloc_cols = 1 - while self.num_tmem_alloc_cols < total: - self.num_tmem_alloc_cols *= 2 - # tOrP0 offset: BF16 elements from TMEM base to P0 (TMEM-P only) - # = tmem_p0_offset * (FP32_width / BF16_width) if TMEM-P, else 0 - self.tOrP0_offset = max(self.tmem_p0_offset, 0) * 2 # Python int - cta = cute.size(qk_mma.thr_id.shape) - q_s = cute.slice_(self.q_smem_s,(None,None,None,0)) - k_s = cute.slice_(self.k_smem_s,(None,None,None,0)) - v_s = cute.slice_(self.v_smem_s,(None,None,None,0)) - self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta - self.kv_tx_bytes = (cute.size_in_bytes(self.q_dtype, k_s) + - cute.size_in_bytes(self.q_dtype, v_s)) * cta - - @cute.jit - def __call__(self, q, k, v, c, stream, lse=None, swa_len=None, sink_bias=None, row_sums=None): - self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype - self.a_major = LayoutEnum.from_tensor(q).mma_major_mode() - self.b_major = LayoutEnum.from_tensor(k).mma_major_mode() - v_fmha = cute.make_tensor( - v.iterator, - cute.make_layout( - (self.pv_n_tile, self.s_k, 1), - stride=(1, self.pv_n_tile, self.pv_n_tile * self.s_k), - ), - ) - self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode() - self.c_layout = LayoutEnum.from_tensor(c) - qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM) - pv_a_major = self.a_major if self.use_smem_p else cute.nvgpu.OperandMajorMode.K - pv_source = tcgen05.OperandSource.SMEM if self.use_smem_p else tcgen05.OperandSource.TMEM - pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, pv_a_major, self.v_major, self.qk_acc_dtype, self.cta_group, (128,self.pv_n_tile), pv_source) - self._setup(qk_mma, pv_mma) - q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0)) - tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape) - tma_k,mK = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,qk_mma.thr_id),k,k_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape) - tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape) - epi_s = cute.select(self.c_smem_s,mode=[0,1]) - tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile) - # Always create a valid mLSE tensor for the kernel. - # CuTeDSL doesn't support None parameters in @cute.kernel. - if const_expr(lse is None): - lse = cute.make_tensor(c.iterator, cute.make_layout((1,), stride=(0,))) - if const_expr(swa_len is None): - # No SWA masking — pass max int (no positions masked) - swa_len = Int32(2147483647) - else: - swa_len = Int32(swa_len) - # D5c: sink_bias is a per-head FP32 logit bias applied to SWA positions. - # When None, pass 0.0 (no bias). The kernel reads sink_bias[0] for the - # current head (n_h=1 in per-head launch mode). - if const_expr(sink_bias is None): - # D5c: sink_bias not provided. Create a dummy tensor pointing to valid memory. - # Never actually read (const_expr(self.n_comp > 0) guards the read). - sink_bias = cute.make_tensor(lse.iterator, cute.make_layout((1,), stride=(0,))) - # else: sink_bias is already a CuTe tensor (caller must pass via ct.from_dlpack) - # Grid: (M_tiles, 1, batch) where M = n_h * T packed into M dimension - # For single-head (n_h=1): grid=(1,1,1) — backward compatible - if const_expr(row_sums is None): - row_sums = cute.make_tensor(lse.iterator, lse.layout) - - self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse,swa_len,sink_bias,row_sums).launch(grid=(1,1,self.batch_size),block=[self.threads_per_cta,1,1],stream=stream) - - @cute.kernel - def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE, swa_len, mSinkBias, mRowSums): - warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - tidx,_,_ = cute.arch.thread_idx() - if warp_idx == self.tma_warp_id: - cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c) - - @cute.struct - class SS: - q_bar: cute.struct.MemRange[cutlass.Int64, self.q_stage*2] - kv_bar: cute.struct.MemRange[cutlass.Int64, self.kv_stage*2] - s_bar: cute.struct.MemRange[cutlass.Int64, 2] - acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage*2] - tmem_dealloc: cutlass.Int64; holding: cutlass.Int32 - smem = utils.SmemAllocator(); st = smem.allocate(SS) - - qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants() - kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants() - s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants() - softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id)) - final_o_bar = pipeline.NamedBarrier(barrier_id=4, num_threads=32 + 32*len(self.epilogue_warp_id)) - # D1.5: pv_done_bar for SMEM accumulator approach. - # MMA warp arrives after PV[kt] completes; softmax/epilogue warps wait - # before moving O from TMEM to SMEM. - pv_done_bar = pipeline.NamedBarrier(barrier_id=5, num_threads=32 + 32*len(self.epilogue_warp_id)) - acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True) - tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id))) - tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) - pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk,is_relaxed=True) - - sQ = smem.allocate_tensor(element_type=self.q_dtype,layout=q_smem_s.outer,byte_alignment=128,swizzle=q_smem_s.inner) - sK = smem.allocate_tensor(element_type=self.q_dtype,layout=k_smem_s.outer,byte_alignment=128,swizzle=k_smem_s.inner) - # sV: independent allocation. At hd=512, pv_n_tile=128 keeps sV at 32KB. - # TODO: overlap sQ/sV with pv_n_tile=256 for better math throughput. - sV = smem.allocate_tensor(element_type=self.q_dtype,layout=v_smem_s.outer,byte_alignment=128,swizzle=v_smem_s.inner) - sC = smem.allocate_tensor(element_type=self.o_dtype,layout=c_smem_s.outer,byte_alignment=128,swizzle=c_smem_s.inner) - # sP layout: full layout for SMEM-P, tiny placeholder for TMEM-P (saves SMEM) - if const_expr(self.use_smem_p): - _p_layout = p_smem_s.outer - _p_swizzle = p_smem_s.inner - else: - _p_layout = cute.make_layout(((1,1),1,(1,1),1)) - _p_swizzle = cute.make_layout(((1,1),1,(1,1),1)) - sP = smem.allocate_tensor(element_type=self.q_dtype,layout=_p_layout,byte_alignment=128,swizzle=_p_swizzle) - - gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None)) - gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None)) - gV = cute.local_tile(mV,cute.slice_(self.pv_mma_tiler,(0,None,None)),(None,None,None)) - gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None)) - n_kv_tiles = cute.size(gK, mode=[3]) - - qk_thr = qk_mma.get_slice(0); pv_thr = pv_mma.get_slice(0) - tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK) - tCgV = pv_thr.partition_B(gV); tCgC = pv_thr.partition_C(gC) - a_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,0,None,0)).shape) - tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(sQ,0,3),cute.group_modes(tCgQ,0,3)) - b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) - tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) - tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) - tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)] - - tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) - tCrV = pv_mma.make_fragment_B(sV) - - qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) - tStS = qk_thr.make_fragment_C(qk_as) - tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) - pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) - tOtO = pv_thr.make_fragment_C(pv_as) - tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) - - # PV A-operand: define both tOrP0 (TMEM-P) and tCrP (SMEM-P) unconditionally. - # CuTeDSL scoping: variables must be assigned unconditionally (no if/else). - tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) - tOrP_base = pv_thr.make_fragment_A(tP if not self.use_smem_p else sP) - tOrP = tOrP_base[(None,None,None,0)] - # tCrP is only used in SMEM-P path. Define unconditionally for CuTeDSL scoping. - tCrP = pv_mma.make_fragment_A(sP) if self.use_smem_p else pv_mma.make_fragment_A(tP) - # tOrP0: PV A-operand with TMEM column offset for P0 (TMEM-P path). - # self.tOrP0_offset is pre-computed in _setup as a Python int. - # Use const_expr if/else for compile-time conditional. - if const_expr(self.tOrP0_offset > 0): - tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout) - else: - tOrP0 = tOrP - - tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage)) - pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) - - # ===== TMA LOAD warp ===== - if warp_idx == self.tma_warp_id: - if const_expr(self.n_k_sub_tiles > 1): - # K sub-tiling path (hd>256): use cutlass.range loop to avoid IR explosion - # from Python range unrolling. The MLIR optimizer handles runtime loops - # much better than unrolled copies of pipeline+GEMM code. - qp.reset() - kvp.reset() - for k_sub in cutlass.range(0, self.n_k_sub_tiles, 1, unroll=1): - qh = qp.acquire_and_advance() - cute.copy(tma_q, tAgQ[(None, k_sub)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) - kvh = kvp.acquire_and_advance() - cute.copy(tma_k, tBgK[(None, k_sub)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - # Load V[0] - kvh_v = kvp.acquire_and_advance() - cute.copy(tma_v, tVgV[(None, Int32(0))], tVsV[(None, kvh_v.index)], tma_bar_ptr=kvh_v.barrier) - qp.tail() - kvp.tail() - else: - # Original pipeline path (hd≤256) - qp.reset(); qh = qp.acquire_and_advance() - cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) - qp.tail() - kvp.reset(); pk = kvp.try_acquire() - for kt in cutlass.range(0, self.n_kv_tiles, 1, unroll=1): - kvh = kvp.acquire_and_advance(pk) - cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - pk = cutlass.Boolean(1) - kvp.tail() - - # ===== MMA warp ===== - if warp_idx == self.mma_warp_id: - tmem.wait_for_alloc() - if const_expr(self.n_k_sub_tiles > 1): - # K sub-tiling path (hd>256): cutlass.range loop (runtime, not unrolled) - qc.reset() - kvc.reset() - qk_mma.set(tcgen05.Field.ACCUMULATE, False) - for k_sub in cutlass.range(0, self.n_k_sub_tiles, 1, unroll=1): - qh = qc.wait_and_advance(); qh.release() - kvh = kvc.wait_and_advance() - for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True): - cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0) - qk_mma.set(tcgen05.Field.ACCUMULATE, True) - kvh.release() - # After all k_sub: S has full QK for this kt - cute.arch.fence_view_async_tmem_store() - softmax_done_bar.arrive() - softmax_done_bar.arrive_and_wait() - pv_mma.set(tcgen05.Field.ACCUMULATE, False) - # Load V: consume from K/V pipeline - kvh_v = kvc.wait_and_advance() - if not self.use_smem_p: - for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True): - cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh_v.index)], tOtO0) - pv_mma.set(tcgen05.Field.ACCUMULATE, True) - else: - for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True): - cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh_v.index)], tOtO0) - pv_mma.set(tcgen05.Field.ACCUMULATE, True) - cute.arch.fence_view_async_tmem_store() - kvh_v.release() - pv_done_bar.arrive() # D1.5: Signal epilogue warps O_kt ready in TMEM - final_o_bar.arrive() - else: - # Original pipeline path (hd≤256) - qc.reset(); qh = qc.wait_and_advance(); qh.release() - kvc.reset(); pk = kvc.try_wait() - acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage) - acc_pipe.producer_acquire(acc_st) - for kt in range(self.n_kv_tiles): - kvh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1) - sh = s_prod.acquire_and_advance() - qk_mma.set(tcgen05.Field.ACCUMULATE, False) - for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True): - cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0) - qk_mma.set(tcgen05.Field.ACCUMULATE, True) - cute.arch.fence_view_async_tmem_store() - sh.commit() - softmax_done_bar.arrive_and_wait() - pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0) - if not self.use_smem_p: - for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True): - cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0) - pv_mma.set(tcgen05.Field.ACCUMULATE, True) - else: - for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True): - cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh.index)], tOtO0) - pv_mma.set(tcgen05.Field.ACCUMULATE, True) - cute.arch.fence_view_async_tmem_store() - kvh.release() - pv_done_bar.arrive() # D1.5: Signal epilogue warps O_kt ready in TMEM - acc_pipe.producer_commit(acc_st); acc_st.advance() - final_o_bar.arrive() - acc_pipe.producer_tail(acc_st) - - # ===== SOFTMAX + CORRECTION EPILOGUE warps ===== - if warp_idx < self.mma_warp_id: - tmem.allocate(self.num_tmem_alloc_cols) - tmem.wait_for_alloc() - tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) - sfw_idx = tidx % (32 * len(self.epilogue_warp_id)) - - # S load atoms - tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) - tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0) - thr_load = tiled_tmem_load.get_slice(sfw_idx) - tTMEM_LOADtS = thr_load.partition_S(tStS0) - cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) - tScS = qk_thr.partition_C(cS) - tTMEM_LOADcS = thr_load.partition_D(tScS) - - # P store atoms: TMEM-P (always defined, only used when use_smem_p=False) - p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width - tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32))) - # Use 0 as P offset when SMEM-P (these atoms are never used, but must be valid) - tStP0 = cute.make_tensor(tStS.iterator + max(self.tmem_p0_offset, 0), tStP_layout) - tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) - tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0) - thr_store = tiled_tmem_store.get_slice(sfw_idx) - tTMEM_STOREtP = thr_store.partition_D(tStP0) - tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32))) - tScP = cute.make_tensor(tScS.iterator, tScP_layout) - tTMEM_STOREcP = thr_store.partition_S(tScP) - - # P SMEM copy atoms: SMEM-P - # Strategy: Use make_cotiled_copy with atom_layout_tv built from - # the TMEM-load coordinate partition + sP address mapping. - # - # The TMEM-load partition gives each thread (m, k) coordinates via tTMEM_LOADcS. - # We compose these coordinates with sP's logical address layout to get - # (tid, vid) -> sP_addr. Then make_cotiled_copy creates a proper TiledCopy. - # - # Key: sP's outer layout maps (m, k0, k1, k2) -> sP_addr with strides (64, 1, 16, 8192). - # We need to build atom_layout_tv in sP's flat address space, not tStS's. - # - # Step 1: Build sP address mapping in the same coordinate system as tStS. - # sP is indexed as ((m, k%16), 0, ((k//16)%4, k//64)) with strides ((64,1),0,(16,8192)). - # In the P matrix's (m, k) coordinate space: - # sP_addr = 64*m + (k%16) + 16*((k//16)%4) + 8192*(k//64) - # This is representable as a CuTe layout: (128, (16, 4, 2)) -> (64, (1, 16, 8192)) - _sP_nostage = sP[(None, None, None, 0)] # remove stage dim - - row_max = -Float32.inf - row_sum = Float32(0.0) - scale_log2 = Float32(self.scale_softmax_log2) - - # ============================================================ - # D1.5: O RESCALE — SMEM ACCUMULATOR APPROACH - # ================================================= - # TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken: - # even NO-OP round-trip corrupts data (ratio = -11 billion). - # Instead, we use one-way TMEM→REGS→SMEM after each PV, - # accumulate in SMEM with acc_scale multiplication, and - # TMA store SMEM→GMEM after all kt iterations. - # - # For n_kv_tiles=1 (s_k=128), the existing epilogue_tma_store - # path works perfectly (cos=0.999998). The SMEM accumulator - # is only needed for n_kv_tiles > 1. - # ============================================================ - - # NOTE: The code below is the BROKEN TMEM round-trip approach. - # It's kept as reference but should NOT be used. - # The SMEM accumulator implementation is TODO. - - # prev_acc_scale: unused, kept for clarity. acc_scale at kt is used - # to rescale O from kt=0..kt-1 before PV[kt]. - prev_acc_scale = Float32(0.0) - - for kt in range(self.n_kv_tiles): - si_handle = s_cons.wait_and_advance() - - tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype) - cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS) - cute.arch.fence_view_async_tmem_load() - - # D3/D4/D5c: In-kernel logit modification. - # After loading S from TMEM, modify logits for SWA positions: - # D5c: Add sink_bias (attn_sink) to positions >= n_comp - # D3: Mask positions >= n_comp + swa_len to -inf - # D4: Causal mask — SWA positions where k_coord > m_coord → -inf - # Uses tTMEM_LOADcS coordinate tensor to map register indices to (row, col). - # For kt > 0, absolute KV pos = kt*128 + k_coord. - if const_expr(self.apply_swa_mask or self.is_causal or self.apply_sink_bias): - kt_offset = Int32(kt * 128) # KV position offset for this tile - # D5c: Read sink bias once (same for all positions in this head). - # Define unconditionally for CuTeDSL scoping (used when apply_sink_bias). - # The bias must be added in the SCALED-LOG2 domain: attn_sink * log2(e). - # But we add to the RAW logits before the scale_log2 multiply. - # Raw correction: attn_sink / scale → after * scale_log2 → attn_sink * log2(e) - sink_val = Float32(0.0) - if const_expr(self.apply_sink_bias): - sink_val = mSinkBias[Int32(0)] / Float32(self.scale_softmax) - for j0 in range(32): - for j1 in range(4): - coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0] - m_coord = coord[0] # query row position - k_coord = coord[1] # position within this KV tile - kv_pos = kt_offset + k_coord # absolute KV position - # D5c: Add sink bias to SWA positions (>= n_comp) - if const_expr(self.apply_sink_bias): - if kv_pos >= Int32(self.n_comp): - tTMEM_LOADrS[(j0, 0), j1, 0, 0] = tTMEM_LOADrS[(j0, 0), j1, 0, 0] + sink_val - # D3: SWA length mask - should_mask = Boolean(0) - if const_expr(self.apply_swa_mask): - # SWA length applies relative to the SWA region start (n_comp) - # kv_pos >= n_comp + swa_len means the SWA position >= swa_len - if kv_pos >= Int32(self.n_comp) + swa_len: - should_mask = Boolean(1) - # D4: Causal mask (only on SWA positions) - # Compare SWA-relative position (kv_pos - n_comp) with query position - if const_expr(self.is_causal): - if kv_pos >= Int32(self.n_comp): - swa_pos = kv_pos - Int32(self.n_comp) - if swa_pos > m_coord: - should_mask = Boolean(1) - if should_mask: - tTMEM_LOADrS[(j0, 0), j1, 0, 0] = -Float32.inf - - old_row_max = row_max - frg_cnt = 4 - frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt - tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) - for j in range(frg_cnt): - for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): - row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2) - - row_max_safe = row_max - if row_max == -cutlass.Float32.inf: - row_max_safe = Float32(0.0) - - acc_scale_ = old_row_max - row_max_safe - acc_scale = cute.math.exp2(acc_scale_, fastmath=True) - if old_row_max == -cutlass.Float32.inf: - acc_scale = Float32(0.0) - row_sum *= acc_scale - - rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype) - rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout) - minus_row_max = Float32(0.0) - row_max_safe - - rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile)) - for j in range(frg_cnt): - for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): - tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max - tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True) - row_sum = row_sum + tTMEM_LOADrS_frg[k, j] - s_vec = tTMEM_LOADrS_frg[None, j].load() - rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype)) - - if not self.use_smem_p: - # TMEM-P: store P to TMEM via register bridge - cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) - cute.arch.fence_view_async_tmem_store() - else: - # SMEM-P: write P to sP using coordinate-indexed store. - for j0 in range(32): - for j1 in range(4): - coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0] - m_coord = coord[0] - k_coord = coord[1] - k0 = k_coord % 16 - k1 = (k_coord // 16) % 4 - k2 = k_coord // 64 - _sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0] - cute.arch.fence_proxy("async.shared", space="cta") - # D1.5: O rescale for kt > 0 — NOT YET IMPLEMENTED. - # TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken: - # even NO-OP round-trip corrupts O accumulator data. - # Production path for multi-KV-tile: Python KV merge (cos 0.999998). - # Future: SMEM accumulator approach (one-way TMEM→REGS→SMEM per kt). - # n_kv_tiles=1 is the only supported path for in-kernel processing. - - si_handle.release() - softmax_done_bar.arrive() - - # Wait for MMA's PV[N-1] to commit before reading O. - final_o_bar.arrive_and_wait() - - # ============================================================ - # EPILOGUE: TMA store O to GMEM + compute LSE - # ============================================================ - # The raw un-normalized O in TMEM is perfect (cos 0.999998). - # We use epilogue_tma_store which reads O from TMEM directly via - # the correct get_tmem_load_op layout — no round-trip needed. - # - # For multi-KV-tile: the paired-atom O rescale above (kt>0) ensures - # O is correctly rescaled before this epilogue reads it. - # - # External normalization (D5a path): kernel outputs un-normalized O + - # LSE + row_sum. Caller normalizes using O_norm = O_unnorm / row_sum. - # This is exact and composes with D5c sink bias merge. - # ============================================================ - - # TMA store via CUTLASS epilogue_tma_store (reads raw O from TMEM) - tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) - c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) - c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp) - acc_cons_st = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.num_acc_stage - ) - acc_cons_st = utils.gemm.sm100.epilogue_tma_store( - self, sfw_idx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, - 0, const_expr(lambda x: x), (0, 0, 0), - acc_cons_st, acc_pipe, c_pipe, - ) - c_pipe.producer_tail() - - # Compute LSE: lse = ln(row_sum) + row_max * ln(2) - # Only when emitting un-normalized output (D5a path). - # When normalize=True, LSE is not needed (in-kernel normalization). - # - # Per-row LSE: each softmax thread (sfw_idx 0..127) handles one row. - # sfw_idx maps directly to the row index in the attention matrix. - # All 128 threads write independently to mLSE[sfw_idx] — no sync needed. - if const_expr(not self.normalize): - _row_max_safe = row_max - if row_max == -cutlass.Float32.inf: - _row_max_safe = Float32(0.0) - _ln2 = Float32(0.6931471805599453) # ln(2) - lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2 - mLSE[sfw_idx, Int32(0), Int32(0)] = lse_val - # Also output row_sum for external normalization (D5c) - mRowSums[sfw_idx, Int32(0), Int32(0)] = row_sum - - tmem.relinquish_alloc_permit() - tmem.free(tmem_ptr) diff --git a/dsv4/kernels/attention/fmha_common.cuh b/dsv4/kernels/attention/fmha_common.cuh index e19e87c2..828e272a 100644 --- a/dsv4/kernels/attention/fmha_common.cuh +++ b/dsv4/kernels/attention/fmha_common.cuh @@ -1,9 +1,48 @@ /** - * DSV4 FMHA shared definitions — base header. - * BF16 type, TMEM ops, warp reductions, constants. + * DSV4 FMHA shared definitions — base header for raw CUDA kernels. * - * TMEM operations use uint32_t registers (b32), NOT float. - * Bitcast between float and uint32_t for FP32 TMEM values. + * ================================================================== + * WHY THIS EXISTS + * ================================================================== + * CuTeDSL (the Python DSL for CUTLASS) has fundamental limitations + * on Blackwell SM100 that make certain operations impossible: + * + * 1. TMEM round-trip is BROKEN (Ld32x32bOp/St32x32bOp column mismatch) + * 2. Float-to-int conversion is IMPOSSIBLE (arith.fptosi not lowerable) + * 3. epilogue_tma_store BLOCKS multi-CTA (can't accept flat_divide coords) + * 4. hd=512 MLIR backend HANGS (>3hr optimizer time) + * + * This header provides the building blocks for writing FMHA in raw + * CUDA C++ with inline PTX, bypassing ALL of the above. + * + * ================================================================== + * WHAT WORKS (tested on B200) + * ================================================================== + * - BF16 conversion via inline PTX cvt.rn.bf16.f32 / cvt.f32.bf16 + * - Warp reductions (fmax, sum) + * - TMEM alloc/dealloc via tcgen05 PTX + * - TMEM load/store via tcgen05.ld/st (uint32_t b32 registers) + * - TMEM fence via tcgen05.fence + * + * ================================================================== + * WHAT'S BROKEN / NEEDS WORK + * ================================================================== + * - TMEM load/store column addressing: the exact column offset + * calculation for row groups (8 row-groups per column) needs + * verification. The kernel using these ops hangs on B200. + * - tcgen05.mma (QK/PV GEMM): UMMA SMEM descriptor construction + * is placeholder only. The descriptor bitfield format is known + * (see cute/arch/mma_sm100_desc.hpp SmemDescriptor) but the + * exact values for our Q/K layouts haven't been validated. + * + * ================================================================== + * KEY INSIGHT FOR NVIDIA + * ================================================================== + * CuTeDSL's inability to lower float→int is a fundamental gap. + * Every quantization kernel needs f32→i32. The fact that nvvm.inline_ptx + * also fails suggests the CuTeDSL MLIR pipeline simply doesn't have a + * lowering path for ANY float→integer type conversion. This makes + * quantize-in-epilogue fusion impossible in CuTeDSL. */ #pragma once diff --git a/dsv4/kernels/attention/fmha_epilogue_sm100.cuh b/dsv4/kernels/attention/fmha_epilogue_sm100.cuh index 8dfc157f..6c258b56 100644 --- a/dsv4/kernels/attention/fmha_epilogue_sm100.cuh +++ b/dsv4/kernels/attention/fmha_epilogue_sm100.cuh @@ -1,6 +1,52 @@ /** * DSV4 FMHA Phase 2 — TMEM accumulator + one-way correction epilogue. - * Uses uint32_t TMEM registers (matching CUTLASS PTX syntax). + * + * ================================================================== + * STATUS: BROKEN — kernel HANGS on B200 + * ================================================================== + * + * The concept is correct (the reference kernel proves the math), but the + * TMEM inline PTX operations cause the kernel to hang. Likely causes: + * + * 1. TMEM column addressing is wrong. The tcgen05.ld/st instructions + * take a single uint32_t column address. The exact mapping from + * (row_group, column) to the uint32_t address is unclear from the + * PTX ISA docs. The CUTLASS C++ code uses CuTe tensor abstractions + * that hide the raw addressing. + * + * 2. tcgen05.alloc may need a valid SMEM pointer that has enough + * backing storage. We're passing cvta.to.shared of the dynamic + * SMEM buffer, but the TMEM allocator may need a specific + * alignment or size. + * + * 3. The tcgen05.ld/st may need .pack::16b modifier for BF16 data, + * and the addressing is different for packed vs unpacked modes. + * + * ================================================================== + * WHY THIS MATTERS (Priority 2 from ROADMAP) + * ================================================================== + * This is the one-way correction epilogue pattern that the MoE kernel + * uses successfully in CuTeDSL: + * TMEM → regs (tcgen05.ld) → [normalize + BF16 cast] → GMEM + * + * If this works, it UNBLOCKS: + * - D2 multi-CTA grid (128 Python launches → 1 GPU launch) + * - NVFP4-1.2 (register slot for FP4 amax + pack in epilogue) + * - In-kernel normalize (O / row_sum without TMEM round-trip) + * - D1.5 fix (O rescale in REGISTERS between KV tiles) + * + * ================================================================== + * KEY INSIGHT FOR NVIDIA + * ================================================================== + * The tcgen05 PTX instructions are poorly documented for direct use. + * CUTLASS's CuTe tensor abstractions work but hide the raw addressing. + * CuTeDSL Python can use them via high-level APIs, but those APIs + * can't do float→int (see fmha_common.cuh). Raw CUDA needs the + * low-level PTX, but the column addressing is undocumented. + * + * Request: Document tcgen05.ld/st column addressing for raw PTX use, + * OR provide C-level intrinsics (like ___tmem_load, __tmem_store) + * that handle the addressing automatically. */ #pragma once #include "fmha_common.cuh" diff --git a/dsv4/kernels/attention/fmha_sm100.cuh b/dsv4/kernels/attention/fmha_sm100.cuh index f0cdc12f..6fc81d32 100644 --- a/dsv4/kernels/attention/fmha_sm100.cuh +++ b/dsv4/kernels/attention/fmha_sm100.cuh @@ -1,6 +1,40 @@ /** - * DSV4 FMHA Phase 1 Reference — scalar implementation. - * Uses SMEM for Q and O. Single-thread for correctness. + * DSV4 FMHA Phase 1 Reference — scalar implementation in raw CUDA C++. + * + * ================================================================== + * STATUS: WORKING (cos 0.999999 at hd=64, cos 0.999998 at hd=128) + * ================================================================== + * + * This is the CORRECT reference implementation. It proves that: + * - The online softmax with O rescale approach is mathematically correct + * - D3 SWA masking works + * - Raw CUDA C++ compiles and runs on Blackwell SM100 without CuTeDSL + * + * ================================================================== + * WHY RAW CUDA INSTEAD OF CUTEDSL + * ================================================================== + * CuTeDSL hit 4 fundamental walls on Blackwell: + * 1. TMEM round-trip broken (D1.5) — Ld32x32bOp/St32x32bOp mismatch + * 2. Float→int impossible — arith.fptosi not lowerable to PTX + * 3. epilogue_tma_store blocks multi-CTA + * 4. hd=512 MLIR optimizer hangs + * + * Writing in raw CUDA gives us full PTX control and bypasses all of these. + * This reference kernel took ~2 hours to get working. The equivalent + * CuTeDSL kernel took weeks and still has the D1.5 blocker. + * + * ================================================================== + * LIMITATIONS (intentional — correctness first, performance second) + * ================================================================== + * - Single-thread computation (tid==0 only) — SLOW but CORRECT + * - No TMEM or tensor cores — scalar math only + * - No D4 causal mask or D5c sink bias yet + * - No multi-KV-tile optimization + * + * These are all solvable incrementally. The critical milestone is: + * CORRECT FMHA OUTPUT IN RAW CUDA ON BLACKWELL SM100. + * + * Next phase: Parallelize across threads, add tcgen05.mma for QK/PV. */ #pragma once #include "fmha_common.cuh" diff --git a/dsv4/kernels/attention/fmha_sm100_launch.cu b/dsv4/kernels/attention/fmha_sm100_launch.cu index 7ce58e79..652358aa 100644 --- a/dsv4/kernels/attention/fmha_sm100_launch.cu +++ b/dsv4/kernels/attention/fmha_sm100_launch.cu @@ -1,5 +1,20 @@ /** * DSV4 FMHA Decode — Launch wrapper and PyTorch binding. + * + * ================================================================== + * STATUS: COMPILES but doesn't run via torch.utils.cpp_extension + * ================================================================== + * The kernel compiles cleanly with nvcc (see test_fmha_sm100.py), + * but torch JIT compilation fails due to __bf16 / bf16_t type + * conflicts with PyTorch's -D__CUDA_NO_BFLOAT16_CONVERSIONS__ flag. + * + * Workaround: Use the standalone test (test_fmha_sm100_standalone.cu) + * which compiles with nvcc directly and tests the kernel via CUDA + * runtime APIs (no PyTorch needed). + * + * To fix for production: Replace bf16_t with c10::BFloat16 and use + * AT_DISPATCH_FLOATING_TYPES for type dispatch. Or compile the .cu + * separately with nvcc and load as a shared library. */ #include "fmha_sm100.cuh" diff --git a/dsv4/kernels/attention/fmha_smem_acc.py b/dsv4/kernels/attention/fmha_smem_acc.py deleted file mode 100644 index a0a4e30e..00000000 --- a/dsv4/kernels/attention/fmha_smem_acc.py +++ /dev/null @@ -1,592 +0,0 @@ -"""FMHA kernel: QK -> online softmax -> PV (CuTeDSL, Blackwell SM100). - -Migrated from tests/unit/test_fmha_v3_stage_c.py — Stage C proven path. -P stored to TMEM via register bridge, PV reads from TMEM. -O rescale via SMEM accumulator (one-way TMEM→REGS→SMEM per kt iteration). -Normalization via final TMA store (SMEM→GMEM). -D1.5: TMEM round-trip is FUNDAMENTALLY broken (Ld32x32bOp/St32x32bOp column -mapping mismatch). SMEM accumulator avoids round-trip entirely. -""" -import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline -from cutlass.cute.nvgpu import cpasync, tcgen05 -from cutlass import Float32, BFloat16, Int32, Boolean, const_expr -from cutlass.utils import LayoutEnum -from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset -from cutlass.utils.blackwell_helpers import get_smem_store_op -from cutlass.utils.gemm.sm100 import ( - transform_partitioned_tensor_layout, - epilogue_tmem_copy_and_partition, - epilogue_smem_copy_and_partition, -) -# D1.5: TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken. -# Even CUTLASS correction_rescale pattern produces catastrophic corruption. -# SMEM accumulator approach: one-way TMEM→REGS→SMEM per kt iteration. -import cuda.bindings.driver as cuda -import cutlass.torch as ct -import math - - -class FmhaKernel: - def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1, apply_swa_mask=False, is_causal=False, n_comp=None, apply_sink_bias=False): - # D5c: n_comp = compressed KV length. Sink bias (attn_sink) applies to - # positions >= n_comp. D3/D4 masks also only apply to SWA region. - # When n_comp is None or 0, no offset (backward compatible). - self.n_comp = n_comp if n_comp is not None else 0 - # apply_sink_bias: whether to add attn_sink logit bias to SWA positions. - # Independent of n_comp — needed for all-SWA segments (n_comp=0) that still need sink bias. - # When True, adds sink_bias to positions >= n_comp (which is 0 → all positions). - self.apply_sink_bias = apply_sink_bias - self.head_dim = head_dim - self.s_k = s_k - self.n_kv_tiles = s_k // 128 - self.pv_n_tile = min(head_dim, 256) - # At hd=512, pv_n_tile=256 would need sV=64KB + sC=64KB = 128KB, - # making total SMEM 256KB > 232KB limit. Use pv_n_tile=128 for hd=512 - # (4 PV GEMM passes instead of 2). TODO: overlap sQ/sV to enable pv_n_tile=256. - if head_dim > 256: - self.pv_n_tile = 128 - self.n_pv_tiles = head_dim // self.pv_n_tile - self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64) - self.num_query_heads = num_query_heads - self.batch_size = batch_size - self.normalize = normalize # D5a: False = emit un-normalized O + lse - self.apply_swa_mask = apply_swa_mask # D3: mask logits at positions >= swa_lens - self.is_causal = is_causal # D4: causal mask (k_coord > m_coord) on SWA branch - self.acc_dtype = Float32; self.qk_acc_dtype = Float32 - self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 - self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1 - self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE - self.epilogue_warp_id = (0,1,2,3); self.mma_warp_id = 4; self.tma_warp_id = 5 - self.threads_per_cta = 192 - # K-dim sub-tiling: cap at 256 to keep sQ and sK within SMEM budget - self.k_tile = min(head_dim, 256) - self.n_k_sub_tiles = head_dim // self.k_tile - self.kv_stage = 1 if head_dim > 128 else 2 # Reduce SMEM at large hd - self.q_stage = 1 - self.num_c_stage = 1 if head_dim > 256 else 2 # Reduce SMEM at hd=512 - self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim) - self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e) - - def _setup(self, qk_mma, pv_mma): - qk_ik = cute.size(qk_mma.shape_mnk, mode=[2]) - # QK GEMM K-dim = head_dim. Each MMA sub-tile covers qk_ik*4 elements. - # The tiler K must be head_dim so the QK loop iterates over all K sub-tiles. - self.qk_mma_tiler = (128, 128, self.k_tile) - pv_ik = cute.size(pv_mma.shape_mnk, mode=[2]) - self.pv_mma_tiler = (128, self.pv_n_tile, pv_ik * (128 // pv_ik)) - self.mma_tiler = self.qk_mma_tiler - self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) - self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), self.pv_n_tile, self.qk_mma_tiler[2]) - self.c_layout = LayoutEnum.ROW_MAJOR - self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype) - self.num_ab_stage = 1; self.num_acc_stage = 1 - self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage) - self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage) - self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage) - self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) - self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) - # P SMEM layout (PV A-operand) — used for SMEM-P path - self.p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) - qk_thr = qk_mma.get_slice(0); qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) - tStS = qk_thr.make_fragment_C(qk_as) - pv_thr = pv_mma.get_slice(0); pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) - tOtO = pv_thr.make_fragment_C(pv_as) - self.tmem_s0_offset = 0 - if not self.use_smem_p: - # TMEM-P: S at 0, P at 32, O after P and S - self.tmem_p0_offset = 32 - p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width - p_end = self.tmem_p0_offset + p_cols_fp32 - s_cols = self.qk_mma_tiler[1] - o_after = max(s_cols, p_end) - self.tmem_o0_offset = ((o_after + 31) // 32) * 32 - o_cols = find_tmem_tensor_col_offset(tOtO) - total = self.tmem_o0_offset + o_cols - else: - # SMEM-P: P not in TMEM. S and O share TMEM (sequential). - self.tmem_p0_offset = -1 # unused - self.tmem_o0_offset = 0 - s_cols = self.qk_mma_tiler[1] - o_cols = find_tmem_tensor_col_offset(tOtO) - total = max(s_cols, o_cols) - self.num_tmem_alloc_cols = 1 - while self.num_tmem_alloc_cols < total: - self.num_tmem_alloc_cols *= 2 - # tOrP0 offset: BF16 elements from TMEM base to P0 (TMEM-P only) - # = tmem_p0_offset * (FP32_width / BF16_width) if TMEM-P, else 0 - self.tOrP0_offset = max(self.tmem_p0_offset, 0) * 2 # Python int - cta = cute.size(qk_mma.thr_id.shape) - q_s = cute.slice_(self.q_smem_s,(None,None,None,0)) - k_s = cute.slice_(self.k_smem_s,(None,None,None,0)) - v_s = cute.slice_(self.v_smem_s,(None,None,None,0)) - self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta - self.kv_tx_bytes = (cute.size_in_bytes(self.q_dtype, k_s) + - cute.size_in_bytes(self.q_dtype, v_s)) * cta - - @cute.jit - def __call__(self, q, k, v, c, stream, lse=None, swa_len=None, sink_bias=None, row_sums=None): - self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype - self.a_major = LayoutEnum.from_tensor(q).mma_major_mode() - self.b_major = LayoutEnum.from_tensor(k).mma_major_mode() - v_fmha = cute.make_tensor( - v.iterator, - cute.make_layout( - (self.pv_n_tile, self.s_k, 1), - stride=(1, self.pv_n_tile, self.pv_n_tile * self.s_k), - ), - ) - self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode() - self.c_layout = LayoutEnum.from_tensor(c) - qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM) - pv_a_major = self.a_major if self.use_smem_p else cute.nvgpu.OperandMajorMode.K - pv_source = tcgen05.OperandSource.SMEM if self.use_smem_p else tcgen05.OperandSource.TMEM - pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, pv_a_major, self.v_major, self.qk_acc_dtype, self.cta_group, (128,self.pv_n_tile), pv_source) - self._setup(qk_mma, pv_mma) - q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0)) - tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape) - tma_k,mK = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,qk_mma.thr_id),k,k_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape) - tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape) - epi_s = cute.select(self.c_smem_s,mode=[0,1]) - tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile) - # Always create a valid mLSE tensor for the kernel. - # CuTeDSL doesn't support None parameters in @cute.kernel. - if const_expr(lse is None): - lse = cute.make_tensor(c.iterator, cute.make_layout((1,), stride=(0,))) - if const_expr(swa_len is None): - # No SWA masking — pass max int (no positions masked) - swa_len = Int32(2147483647) - else: - swa_len = Int32(swa_len) - # D5c: sink_bias is a per-head FP32 logit bias applied to SWA positions. - # When None, pass 0.0 (no bias). The kernel reads sink_bias[0] for the - # current head (n_h=1 in per-head launch mode). - if const_expr(sink_bias is None): - # D5c: sink_bias not provided. Create a dummy tensor pointing to valid memory. - # Never actually read (const_expr(self.n_comp > 0) guards the read). - sink_bias = cute.make_tensor(lse.iterator, cute.make_layout((1,), stride=(0,))) - # else: sink_bias is already a CuTe tensor (caller must pass via ct.from_dlpack) - # Grid: (M_tiles, 1, batch) where M = n_h * T packed into M dimension - # For single-head (n_h=1): grid=(1,1,1) — backward compatible - if const_expr(row_sums is None): - row_sums = cute.make_tensor(lse.iterator, lse.layout) - - self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse,swa_len,sink_bias,row_sums).launch(grid=(1,1,self.batch_size),block=[self.threads_per_cta,1,1],stream=stream) - - @cute.kernel - def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE, swa_len, mSinkBias, mRowSums): - warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - tidx,_,_ = cute.arch.thread_idx() - if warp_idx == self.tma_warp_id: - cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c) - - @cute.struct - class SS: - q_bar: cute.struct.MemRange[cutlass.Int64, self.q_stage*2] - kv_bar: cute.struct.MemRange[cutlass.Int64, self.kv_stage*2] - s_bar: cute.struct.MemRange[cutlass.Int64, 2] - acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage*2] - tmem_dealloc: cutlass.Int64; holding: cutlass.Int32 - smem = utils.SmemAllocator(); st = smem.allocate(SS) - - qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants() - kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants() - s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants() - softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id)) - final_o_bar = pipeline.NamedBarrier(barrier_id=4, num_threads=32 + 32*len(self.epilogue_warp_id)) - # D1.5: pv_done_bar for SMEM accumulator approach. - # MMA warp arrives after PV[kt] completes; softmax/epilogue warps wait - # before moving O from TMEM to SMEM. - pv_done_bar = pipeline.NamedBarrier(barrier_id=5, num_threads=32 + 32*len(self.epilogue_warp_id)) - acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True) - tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id))) - tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) - pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk,is_relaxed=True) - - sQ = smem.allocate_tensor(element_type=self.q_dtype,layout=q_smem_s.outer,byte_alignment=128,swizzle=q_smem_s.inner) - sK = smem.allocate_tensor(element_type=self.q_dtype,layout=k_smem_s.outer,byte_alignment=128,swizzle=k_smem_s.inner) - # sV: independent allocation. At hd=512, pv_n_tile=128 keeps sV at 32KB. - # TODO: overlap sQ/sV with pv_n_tile=256 for better math throughput. - sV = smem.allocate_tensor(element_type=self.q_dtype,layout=v_smem_s.outer,byte_alignment=128,swizzle=v_smem_s.inner) - sC = smem.allocate_tensor(element_type=self.o_dtype,layout=c_smem_s.outer,byte_alignment=128,swizzle=c_smem_s.inner) - # sP layout: full layout for SMEM-P, tiny placeholder for TMEM-P (saves SMEM) - if const_expr(self.use_smem_p): - _p_layout = p_smem_s.outer - _p_swizzle = p_smem_s.inner - else: - _p_layout = cute.make_layout(((1,1),1,(1,1),1)) - _p_swizzle = cute.make_layout(((1,1),1,(1,1),1)) - sP = smem.allocate_tensor(element_type=self.q_dtype,layout=_p_layout,byte_alignment=128,swizzle=_p_swizzle) - - gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None)) - gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None)) - gV = cute.local_tile(mV,cute.slice_(self.pv_mma_tiler,(0,None,None)),(None,None,None)) - gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None)) - n_kv_tiles = cute.size(gK, mode=[3]) - - qk_thr = qk_mma.get_slice(0); pv_thr = pv_mma.get_slice(0) - tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK) - tCgV = pv_thr.partition_B(gV); tCgC = pv_thr.partition_C(gC) - a_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,0,None,0)).shape) - tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(sQ,0,3),cute.group_modes(tCgQ,0,3)) - b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) - tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) - tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) - tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)] - - tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) - tCrV = pv_mma.make_fragment_B(sV) - - qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) - tStS = qk_thr.make_fragment_C(qk_as) - tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) - pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) - tOtO = pv_thr.make_fragment_C(pv_as) - tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) - - # PV A-operand: define both tOrP0 (TMEM-P) and tCrP (SMEM-P) unconditionally. - # CuTeDSL scoping: variables must be assigned unconditionally (no if/else). - tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) - tOrP_base = pv_thr.make_fragment_A(tP if not self.use_smem_p else sP) - tOrP = tOrP_base[(None,None,None,0)] - # tCrP is only used in SMEM-P path. Define unconditionally for CuTeDSL scoping. - tCrP = pv_mma.make_fragment_A(sP) if self.use_smem_p else pv_mma.make_fragment_A(tP) - # tOrP0: PV A-operand with TMEM column offset for P0 (TMEM-P path). - # self.tOrP0_offset is pre-computed in _setup as a Python int. - # Use const_expr if/else for compile-time conditional. - if const_expr(self.tOrP0_offset > 0): - tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout) - else: - tOrP0 = tOrP - - tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage)) - pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) - - # ===== TMA LOAD warp ===== - if warp_idx == self.tma_warp_id: - if const_expr(self.n_k_sub_tiles > 1): - # K sub-tiling path (hd>256): use cutlass.range loop to avoid IR explosion - # from Python range unrolling. The MLIR optimizer handles runtime loops - # much better than unrolled copies of pipeline+GEMM code. - qp.reset() - kvp.reset() - for k_sub in cutlass.range(0, self.n_k_sub_tiles, 1, unroll=1): - qh = qp.acquire_and_advance() - cute.copy(tma_q, tAgQ[(None, k_sub)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) - kvh = kvp.acquire_and_advance() - cute.copy(tma_k, tBgK[(None, k_sub)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - # Load V[0] - kvh_v = kvp.acquire_and_advance() - cute.copy(tma_v, tVgV[(None, Int32(0))], tVsV[(None, kvh_v.index)], tma_bar_ptr=kvh_v.barrier) - qp.tail() - kvp.tail() - else: - # Original pipeline path (hd≤256) - qp.reset(); qh = qp.acquire_and_advance() - cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) - qp.tail() - kvp.reset(); pk = kvp.try_acquire() - for kt in cutlass.range(0, self.n_kv_tiles, 1, unroll=1): - kvh = kvp.acquire_and_advance(pk) - cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - pk = cutlass.Boolean(1) - kvp.tail() - - # ===== MMA warp ===== - if warp_idx == self.mma_warp_id: - tmem.wait_for_alloc() - if const_expr(self.n_k_sub_tiles > 1): - # K sub-tiling path (hd>256): cutlass.range loop (runtime, not unrolled) - qc.reset() - kvc.reset() - qk_mma.set(tcgen05.Field.ACCUMULATE, False) - for k_sub in cutlass.range(0, self.n_k_sub_tiles, 1, unroll=1): - qh = qc.wait_and_advance(); qh.release() - kvh = kvc.wait_and_advance() - for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True): - cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0) - qk_mma.set(tcgen05.Field.ACCUMULATE, True) - kvh.release() - # After all k_sub: S has full QK for this kt - cute.arch.fence_view_async_tmem_store() - softmax_done_bar.arrive() - softmax_done_bar.arrive_and_wait() - pv_mma.set(tcgen05.Field.ACCUMULATE, False) - # Load V: consume from K/V pipeline - kvh_v = kvc.wait_and_advance() - if not self.use_smem_p: - for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True): - cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh_v.index)], tOtO0) - pv_mma.set(tcgen05.Field.ACCUMULATE, True) - else: - for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True): - cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh_v.index)], tOtO0) - pv_mma.set(tcgen05.Field.ACCUMULATE, True) - cute.arch.fence_view_async_tmem_store() - kvh_v.release() - pv_done_bar.arrive() # D1.5: Signal epilogue warps O_kt ready in TMEM - final_o_bar.arrive() - else: - # Original pipeline path (hd≤256) - qc.reset(); qh = qc.wait_and_advance(); qh.release() - kvc.reset(); pk = kvc.try_wait() - acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage) - acc_pipe.producer_acquire(acc_st) - for kt in range(self.n_kv_tiles): - kvh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1) - sh = s_prod.acquire_and_advance() - qk_mma.set(tcgen05.Field.ACCUMULATE, False) - for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True): - cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0) - qk_mma.set(tcgen05.Field.ACCUMULATE, True) - cute.arch.fence_view_async_tmem_store() - sh.commit() - softmax_done_bar.arrive_and_wait() - pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0) - if not self.use_smem_p: - for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True): - cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0) - pv_mma.set(tcgen05.Field.ACCUMULATE, True) - else: - for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True): - cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh.index)], tOtO0) - pv_mma.set(tcgen05.Field.ACCUMULATE, True) - cute.arch.fence_view_async_tmem_store() - kvh.release() - pv_done_bar.arrive() # D1.5: Signal epilogue warps O_kt ready in TMEM - acc_pipe.producer_commit(acc_st); acc_st.advance() - final_o_bar.arrive() - acc_pipe.producer_tail(acc_st) - - # ===== SOFTMAX + CORRECTION EPILOGUE warps ===== - if warp_idx < self.mma_warp_id: - tmem.allocate(self.num_tmem_alloc_cols) - tmem.wait_for_alloc() - tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype) - sfw_idx = tidx % (32 * len(self.epilogue_warp_id)) - - # S load atoms - tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) - tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0) - thr_load = tiled_tmem_load.get_slice(sfw_idx) - tTMEM_LOADtS = thr_load.partition_S(tStS0) - cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) - tScS = qk_thr.partition_C(cS) - tTMEM_LOADcS = thr_load.partition_D(tScS) - - # P store atoms: TMEM-P (always defined, only used when use_smem_p=False) - p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width - tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32))) - # Use 0 as P offset when SMEM-P (these atoms are never used, but must be valid) - tStP0 = cute.make_tensor(tStS.iterator + max(self.tmem_p0_offset, 0), tStP_layout) - tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) - tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0) - thr_store = tiled_tmem_store.get_slice(sfw_idx) - tTMEM_STOREtP = thr_store.partition_D(tStP0) - tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32))) - tScP = cute.make_tensor(tScS.iterator, tScP_layout) - tTMEM_STOREcP = thr_store.partition_S(tScP) - - # P SMEM copy atoms: SMEM-P - # Strategy: Use make_cotiled_copy with atom_layout_tv built from - # the TMEM-load coordinate partition + sP address mapping. - # - # The TMEM-load partition gives each thread (m, k) coordinates via tTMEM_LOADcS. - # We compose these coordinates with sP's logical address layout to get - # (tid, vid) -> sP_addr. Then make_cotiled_copy creates a proper TiledCopy. - # - # Key: sP's outer layout maps (m, k0, k1, k2) -> sP_addr with strides (64, 1, 16, 8192). - # We need to build atom_layout_tv in sP's flat address space, not tStS's. - # - # Step 1: Build sP address mapping in the same coordinate system as tStS. - # sP is indexed as ((m, k%16), 0, ((k//16)%4, k//64)) with strides ((64,1),0,(16,8192)). - # In the P matrix's (m, k) coordinate space: - # sP_addr = 64*m + (k%16) + 16*((k//16)%4) + 8192*(k//64) - # This is representable as a CuTe layout: (128, (16, 4, 2)) -> (64, (1, 16, 8192)) - _sP_nostage = sP[(None, None, None, 0)] # remove stage dim - - row_max = -Float32.inf - row_sum = Float32(0.0) - scale_log2 = Float32(self.scale_softmax_log2) - - # ============================================================ - # D1.5: O RESCALE — SMEM ACCUMULATOR APPROACH - # ================================================= - # TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken: - # even NO-OP round-trip corrupts data (ratio = -11 billion). - # Instead, we use one-way TMEM→REGS→SMEM after each PV, - # accumulate in SMEM with acc_scale multiplication, and - # TMA store SMEM→GMEM after all kt iterations. - # - # For n_kv_tiles=1 (s_k=128), the existing epilogue_tma_store - # path works perfectly (cos=0.999998). The SMEM accumulator - # is only needed for n_kv_tiles > 1. - # ============================================================ - - # NOTE: The code below is the BROKEN TMEM round-trip approach. - # It's kept as reference but should NOT be used. - # The SMEM accumulator implementation is TODO. - - # prev_acc_scale: unused, kept for clarity. acc_scale at kt is used - # to rescale O from kt=0..kt-1 before PV[kt]. - prev_acc_scale = Float32(0.0) - - for kt in range(self.n_kv_tiles): - si_handle = s_cons.wait_and_advance() - - tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype) - cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS) - cute.arch.fence_view_async_tmem_load() - - # D3/D4/D5c: In-kernel logit modification. - # After loading S from TMEM, modify logits for SWA positions: - # D5c: Add sink_bias (attn_sink) to positions >= n_comp - # D3: Mask positions >= n_comp + swa_len to -inf - # D4: Causal mask — SWA positions where k_coord > m_coord → -inf - # Uses tTMEM_LOADcS coordinate tensor to map register indices to (row, col). - # For kt > 0, absolute KV pos = kt*128 + k_coord. - if const_expr(self.apply_swa_mask or self.is_causal or self.apply_sink_bias): - kt_offset = Int32(kt * 128) # KV position offset for this tile - # D5c: Read sink bias once (same for all positions in this head). - # Define unconditionally for CuTeDSL scoping (used when apply_sink_bias). - # The bias must be added in the SCALED-LOG2 domain: attn_sink * log2(e). - # But we add to the RAW logits before the scale_log2 multiply. - # Raw correction: attn_sink / scale → after * scale_log2 → attn_sink * log2(e) - sink_val = Float32(0.0) - if const_expr(self.apply_sink_bias): - sink_val = mSinkBias[Int32(0)] / Float32(self.scale_softmax) - for j0 in range(32): - for j1 in range(4): - coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0] - m_coord = coord[0] # query row position - k_coord = coord[1] # position within this KV tile - kv_pos = kt_offset + k_coord # absolute KV position - # D5c: Add sink bias to SWA positions (>= n_comp) - if const_expr(self.apply_sink_bias): - if kv_pos >= Int32(self.n_comp): - tTMEM_LOADrS[(j0, 0), j1, 0, 0] = tTMEM_LOADrS[(j0, 0), j1, 0, 0] + sink_val - # D3: SWA length mask - should_mask = Boolean(0) - if const_expr(self.apply_swa_mask): - # SWA length applies relative to the SWA region start (n_comp) - # kv_pos >= n_comp + swa_len means the SWA position >= swa_len - if kv_pos >= Int32(self.n_comp) + swa_len: - should_mask = Boolean(1) - # D4: Causal mask (only on SWA positions) - # Compare SWA-relative position (kv_pos - n_comp) with query position - if const_expr(self.is_causal): - if kv_pos >= Int32(self.n_comp): - swa_pos = kv_pos - Int32(self.n_comp) - if swa_pos > m_coord: - should_mask = Boolean(1) - if should_mask: - tTMEM_LOADrS[(j0, 0), j1, 0, 0] = -Float32.inf - - old_row_max = row_max - frg_cnt = 4 - frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt - tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) - for j in range(frg_cnt): - for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): - row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2) - - row_max_safe = row_max - if row_max == -cutlass.Float32.inf: - row_max_safe = Float32(0.0) - - acc_scale_ = old_row_max - row_max_safe - acc_scale = cute.math.exp2(acc_scale_, fastmath=True) - if old_row_max == -cutlass.Float32.inf: - acc_scale = Float32(0.0) - row_sum *= acc_scale - - rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype) - rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout) - minus_row_max = Float32(0.0) - row_max_safe - - rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile)) - for j in range(frg_cnt): - for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): - tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max - tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True) - row_sum = row_sum + tTMEM_LOADrS_frg[k, j] - s_vec = tTMEM_LOADrS_frg[None, j].load() - rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype)) - - if not self.use_smem_p: - # TMEM-P: store P to TMEM via register bridge - cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) - cute.arch.fence_view_async_tmem_store() - else: - # SMEM-P: write P to sP using coordinate-indexed store. - for j0 in range(32): - for j1 in range(4): - coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0] - m_coord = coord[0] - k_coord = coord[1] - k0 = k_coord % 16 - k1 = (k_coord // 16) % 4 - k2 = k_coord // 64 - _sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0] - cute.arch.fence_proxy("async.shared", space="cta") - # D1.5: O rescale for kt > 0 — NOT YET IMPLEMENTED. - # TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken: - # even NO-OP round-trip corrupts O accumulator data. - # Production path for multi-KV-tile: Python KV merge (cos 0.999998). - # Future: SMEM accumulator approach (one-way TMEM→REGS→SMEM per kt). - # n_kv_tiles=1 is the only supported path for in-kernel processing. - - si_handle.release() - softmax_done_bar.arrive() - - # Wait for MMA's PV[N-1] to commit before reading O. - final_o_bar.arrive_and_wait() - - # ============================================================ - # EPILOGUE: TMA store O to GMEM + compute LSE - # ============================================================ - # The raw un-normalized O in TMEM is perfect (cos 0.999998). - # We use epilogue_tma_store which reads O from TMEM directly via - # the correct get_tmem_load_op layout — no round-trip needed. - # - # For multi-KV-tile: the paired-atom O rescale above (kt>0) ensures - # O is correctly rescaled before this epilogue reads it. - # - # External normalization (D5a path): kernel outputs un-normalized O + - # LSE + row_sum. Caller normalizes using O_norm = O_unnorm / row_sum. - # This is exact and composes with D5c sink bias merge. - # ============================================================ - - # TMA store via CUTLASS epilogue_tma_store (reads raw O from TMEM) - tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) - c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) - c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp) - acc_cons_st = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.num_acc_stage - ) - acc_cons_st = utils.gemm.sm100.epilogue_tma_store( - self, sfw_idx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, - 0, const_expr(lambda x: x), (0, 0, 0), - acc_cons_st, acc_pipe, c_pipe, - ) - c_pipe.producer_tail() - - # Compute LSE: lse = ln(row_sum) + row_max * ln(2) - # Only when emitting un-normalized output (D5a path). - # When normalize=True, LSE is not needed (in-kernel normalization). - # - # Per-row LSE: each softmax thread (sfw_idx 0..127) handles one row. - # sfw_idx maps directly to the row index in the attention matrix. - # All 128 threads write independently to mLSE[sfw_idx] — no sync needed. - if const_expr(not self.normalize): - _row_max_safe = row_max - if row_max == -cutlass.Float32.inf: - _row_max_safe = Float32(0.0) - _ln2 = Float32(0.6931471805599453) # ln(2) - lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2 - mLSE[sfw_idx, Int32(0), Int32(0)] = lse_val - # Also output row_sum for external normalization (D5c) - mRowSums[sfw_idx, Int32(0), Int32(0)] = row_sum - - tmem.relinquish_alloc_permit() - tmem.free(tmem_ptr) diff --git a/dsv4/kernels/attention/production.py b/dsv4/kernels/attention/production.py index 64c14d3e..cb395492 100644 --- a/dsv4/kernels/attention/production.py +++ b/dsv4/kernels/attention/production.py @@ -1,5 +1,29 @@ """DSV4 Blackwell Attention — Production kernel wrapper. +==================================================================== +STATUS: WORKING for single-tile, Python KV merge for multi-tile +==================================================================== + +See ROADMAP.md Priority 5 (Stage E) for what's needed to ship. +Key gaps: custom_op registration, kernel cache warmup, batch fusion. + +==================================================================== +WHAT WORKS +==================================================================== +- Per-KV-group head-packed launch (MQA/GQA efficient) +- Python KV merge for multi-KV-tile (cos 0.999998) +- D3/D4/D5c masks +- Batch via Python outer loop + +==================================================================== +WHAT'S BLOCKED +==================================================================== +- In-kernel multi-KV-tile: blocked on D1.5 (TMEM round-trip broken) +- Batch fusion into grid: blocked on D2 (multi-CTA, epilogue_tma_store) +- hd > 256: CuTeDSL MLIR hang (>3hr optimizer time) + +==================================================================== + Wraps the CuTeDSL FMHA kernel with Python KV merge for multi-KV-tile. Supports MHA, MQA, and GQA attention patterns with head-packed launches for efficient MQA/GQA (all Q heads sharing a KV head dispatched in one