From afb93eae22adfab0a43d33b3060b3eb71e110962 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 26 May 2026 20:55:16 +0000 Subject: [PATCH] D1.5: Revert broken TMEM round-trip O rescale, document as fundamentally broken MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit TMEM round-trip via Ld32x32bOp/St32x32bOp corrupts O accumulator data even with CUTLASS correction_rescale pattern. All variants tested: - Repetition(16) + composition (CUTLASS exact pattern) — BROKEN - Repetition(32) + composition — BROKEN - Repetition(16) raw layout (no composition) — BROKEN Even NO-OP (multiply by 1.0) produces catastrophically wrong results. Production path remains Python KV merge (cos 0.999998 for s_k up to 1024). Next: SMEM accumulator approach (one-way TMEM→REGS→SMEM per kt). --- dsv4/kernels/attention/fmha.py | 116 ++++-------------- dsv4/kernels/attention/fmha_smem_acc.py | 151 ++++++++++++++++++++++++ 2 files changed, 176 insertions(+), 91 deletions(-) create mode 100644 dsv4/kernels/attention/fmha_smem_acc.py diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 761fdaeb..b48f5c8f 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -21,7 +21,7 @@ import math class FmhaKernel: - def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1, apply_swa_mask=False, is_causal=False, n_comp=None, apply_sink_bias=False, debug_noop_rescale=False): + def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1, apply_swa_mask=False, is_causal=False, n_comp=None, apply_sink_bias=False): # D5c: n_comp = compressed KV length. Sink bias (attn_sink) applies to # positions >= n_comp. D3/D4 masks also only apply to SWA region. # When n_comp is None or 0, no offset (backward compatible). @@ -58,8 +58,6 @@ class FmhaKernel: self.kv_stage = 1 if head_dim > 128 else 2 # Reduce SMEM at large hd self.q_stage = 1 self.num_c_stage = 1 if head_dim > 256 else 2 # Reduce SMEM at hd=512 - self.debug_noop_rescale = debug_noop_rescale # D1.5 debug: force acc_scale=1.0 in O rescale - self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim) self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e) @@ -189,9 +187,8 @@ class FmhaKernel: s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants() softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id)) final_o_bar = pipeline.NamedBarrier(barrier_id=4, num_threads=32 + 32*len(self.epilogue_warp_id)) - # D1.5: barrier for PV completion signal (MMA→softmax warps) - # MMA warp arrives after PV[kt] completes; softmax warps wait before O rescale. - pv_done_bar = pipeline.NamedBarrier(barrier_id=5, num_threads=32 + 32*len(self.epilogue_warp_id)) + # D1.5: pv_done_bar for O rescale (currently unused — TMEM round-trip broken) + # pv_done_bar = pipeline.NamedBarrier(barrier_id=5, num_threads=32 + 32*len(self.epilogue_warp_id)) acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True) tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id))) tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) @@ -319,8 +316,7 @@ class FmhaKernel: pv_mma.set(tcgen05.Field.ACCUMULATE, True) cute.arch.fence_view_async_tmem_store() kvh_v.release() - if const_expr(self.n_kv_tiles > 1): - pv_done_bar.arrive() # Signal softmax warps: PV done, O is ready for rescale + # pv_done_bar.arrive() # D1.5: unused — TMEM round-trip broken final_o_bar.arrive() else: # Original pipeline path (hd≤256) @@ -338,8 +334,6 @@ class FmhaKernel: cute.arch.fence_view_async_tmem_store() sh.commit() softmax_done_bar.arrive_and_wait() - if const_expr(self.n_kv_tiles > 1): - cute.arch.fence_view_async_tmem_load() # Ensure rescaled O visible before PV[kt] pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0) if not self.use_smem_p: for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True): @@ -351,8 +345,7 @@ class FmhaKernel: pv_mma.set(tcgen05.Field.ACCUMULATE, True) cute.arch.fence_view_async_tmem_store() kvh.release() - if const_expr(self.n_kv_tiles > 1): - pv_done_bar.arrive() # Signal softmax warps: PV done, O ready for rescale + # pv_done_bar.arrive() # D1.5: unused — TMEM round-trip broken acc_pipe.producer_commit(acc_st); acc_st.advance() final_o_bar.arrive() acc_pipe.producer_tail(acc_st) @@ -409,53 +402,22 @@ class FmhaKernel: scale_log2 = Float32(self.scale_softmax_log2) # ============================================================ - # D1.5: O RESCALE ATOMS (CUTLASS correction_rescale pattern) + # D1.5: O RESCALE — SMEM ACCUMULATOR APPROACH + # ================================================= + # TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken: + # even NO-OP round-trip corrupts data (ratio = -11 billion). + # Instead, we use one-way TMEM→REGS→SMEM after each PV, + # accumulate in SMEM with acc_scale multiplication, and + # TMA store SMEM→GMEM after all kt iterations. + # + # For n_kv_tiles=1 (s_k=128), the existing epilogue_tma_store + # path works perfectly (cos=0.999998). The SMEM accumulator + # is only needed for n_kv_tiles > 1. # ============================================================ - # Pattern: both load and store atoms built from the SAME tOtO_i - # (composition-tiled from tOtO0), same Repetition(corr_tile_size). - # This is the exact pattern from CUTLASS reference fmha.py line 2123. - # The key insight: using composition() to re-tile tOtO into (128, corr_tile_size) - # sub-tiles, and building BOTH copies from the SAME tensor, ensures the - # column mappings agree on round-trip. - # ============================================================ - corr_tile_size = 16 # Must be power of 2, divides head_dim - # Try both composition and raw layout - use_comp = True - if const_expr(use_comp): - tOtO_i_layout = cute.composition( - tOtO0.layout, cute.make_layout((128, corr_tile_size)) - ) - else: - tOtO_i_layout = tOtO0.layout - tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout) - # Coordinate tensor for O (needed for partition_D of load) - cO = cute.make_identity_tensor((128, self.head_dim)) - tOcO = pv_thr.partition_C(cO) - if const_expr(use_comp): - tOcO_i_layout = cute.composition( - tOcO.layout, cute.make_layout((128, corr_tile_size)) - ) - else: - tOcO_i_layout = tOcO.layout - tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout) - - tmem_load_o_atom = cute.make_copy_atom( - tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), - self.qk_acc_dtype, - ) - tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i) - thr_tmem_load_o = tiled_tmem_load_o.get_slice(sfw_idx) - tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i) - tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i) - - tmem_store_o_atom = cute.make_copy_atom( - tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), - self.qk_acc_dtype, - ) - tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i) - thr_tmem_store_o = tiled_tmem_store_o.get_slice(sfw_idx) - tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i) + # NOTE: The code below is the BROKEN TMEM round-trip approach. + # It's kept as reference but should NOT be used. + # The SMEM accumulator implementation is TODO. # prev_acc_scale: unused, kept for clarity. acc_scale at kt is used # to rescale O from kt=0..kt-1 before PV[kt]. @@ -559,40 +521,12 @@ class FmhaKernel: k2 = k_coord // 64 _sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0] cute.arch.fence_proxy("async.shared", space="cta") - # D1.5: O rescale for kt > 0 — CUTLASS correction_rescale pattern. - # After computing acc_scale for this iteration, rescale the existing O - # in TMEM before the next PV GEMM adds to it. - # Must wait for PV[kt-1] to complete (MMA signals pv_done_bar). - if const_expr(self.n_kv_tiles > 1): - if kt > 0: - pv_done_bar.arrive_and_wait() # Wait for PV[kt-1] - # Rescale O: load, multiply by acc_scale, store back to TMEM. - # CUTLASS pattern: both copies use same tOtO_i (composition-tiled). - rescale_factor = acc_scale - if const_expr(self.debug_noop_rescale): - rescale_factor = Float32(1.0) - n_slices = self.head_dim // corr_tile_size - tTMrO = cute.make_rmem_tensor( - (tTMEM_LOADcO.shape, n_slices), self.qk_acc_dtype - ) - for i in range(n_slices): - tTMrO_i_ = tTMrO[None, i] - tTMrO_i_layout = cute.composition( - tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) - ) - tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) - tTMEM_LOADtO_i = cute.make_tensor( - tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout - ) - tTMEM_STOREtO_i = cute.make_tensor( - tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout - ) - cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i) - cute.arch.fence_view_async_tmem_load() - for k in cutlass.range(cute.size(tTMrO_i), vectorize=True): - tTMrO_i[k] = tTMrO_i[k] * rescale_factor - cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i) - cute.arch.fence_view_async_tmem_store() + # D1.5: O rescale for kt > 0 — NOT YET IMPLEMENTED. + # TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken: + # even NO-OP round-trip corrupts O accumulator data. + # Production path for multi-KV-tile: Python KV merge (cos 0.999998). + # Future: SMEM accumulator approach (one-way TMEM→REGS→SMEM per kt). + # n_kv_tiles=1 is the only supported path for in-kernel processing. si_handle.release() softmax_done_bar.arrive() diff --git a/dsv4/kernels/attention/fmha_smem_acc.py b/dsv4/kernels/attention/fmha_smem_acc.py new file mode 100644 index 00000000..a899d98d --- /dev/null +++ b/dsv4/kernels/attention/fmha_smem_acc.py @@ -0,0 +1,151 @@ +"""FMHA kernel: QK -> online softmax -> PV (CuTeDSL, Blackwell SM100). + +SMEM accumulator approach for multi-KV-tile O rescale. +Instead of TMEM round-trip (which corrupts data), we move O from TMEM +to SMEM after each PV GEMM via one-way epilogue, and accumulate in SMEM. + +This avoids the D1.5 TMEM round-trip bug entirely. + +Architecture: +- 6-warp specialization: 4 softmax+epilogue, 1 MMA, 1 TMA +- After PV[kt]: one-way TMEM→REGS→SMEM with acc_scale multiplication +- SMEM accumulator persists across kt iterations +- Final TMA store: SMEM→GMEM + +Per-kt flow: +1. Softmax warps: compute P[kt], acc_scale[kt] +2. Signal softmax_done_bar +3. MMA warp: PV[kt] GEMM (ACCUMULATE=False, fresh TMEM) +4. Signal pv_done_bar +5. Softmax/epilogue warps: TMEM→REGS, acc_scale*O_acc + O_kt, REGS→SMEM +6. Repeat for next kt +7. After all kt: SMEM→GMEM via TMA +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +from cutlass.utils.blackwell_helpers import get_smem_store_op +from cutlass.utils.gemm.sm100 import ( + transform_partitioned_tensor_layout, + epilogue_tmem_copy_and_partition, + epilogue_smem_copy_and_partition, +) +import cuda.bindings.driver as cuda +import cutlass.torch as ct +import math + + +class FmhaKernel: + def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, + num_query_heads=1, batch_size=1, apply_swa_mask=False, is_causal=False, + n_comp=None, apply_sink_bias=False): + self.n_comp = n_comp if n_comp is not None else 0 + self.apply_sink_bias = apply_sink_bias + self.head_dim = head_dim + self.s_k = s_k + self.n_kv_tiles = s_k // 128 + self.pv_n_tile = min(head_dim, 256) + if head_dim > 256: + self.pv_n_tile = 128 + self.n_pv_tiles = head_dim // self.pv_n_tile + self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64) + self.num_query_heads = num_query_heads + self.batch_size = batch_size + self.normalize = normalize + self.apply_swa_mask = apply_swa_mask + self.is_causal = is_causal + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 + self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1 + self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim) + self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e) + + def _setup(self, qk_mma, pv_mma): + qk_ik = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (128, 128, self.k_tile) + pv_ik = cute.size(pv_mma.shape_mnk, mode=[2]) + self.pv_mma_tiler = (128, self.pv_n_tile, pv_ik * (128 // pv_ik)) + self.mma_tiler = self.qk_mma_tiler + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), self.pv_n_tile, self.qk_mma_tiler[2]) + self.c_layout = LayoutEnum.ROW_MAJOR + self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage) + self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + qk_thr = qk_mma.get_slice(0); qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_as) + pv_thr = pv_mma.get_slice(0); pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_as) + self.tmem_s0_offset = 0 + if not self.use_smem_p: + self.tmem_p0_offset = 32 + p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width + p_end = self.tmem_p0_offset + p_cols_fp32 + s_cols = self.qk_mma_tiler[1] + o_after = max(s_cols, p_end) + self.tmem_o0_offset = ((o_after + 31) // 32) * 32 + o_cols = find_tmem_tensor_col_offset(tOtO) + total = self.tmem_o0_offset + o_cols + else: + self.tmem_p0_offset = -1 + self.tmem_o0_offset = 0 + s_cols = self.qk_mma_tiler[1] + o_cols = find_tmem_tensor_col_offset(tOtO) + total = max(s_cols, o_cols) + self.num_tmem_alloc_cols = 1 + while self.num_tmem_alloc_cols < total: + self.num_tmem_alloc_cols *= 2 + self.tOrP0_offset = max(self.tmem_p0_offset, 0) * 2 + cta = cute.size(qk_mma.thr_id.shape) + q_s = cute.slice_(self.q_smem_s,(None,None,None,0)) + k_s = cute.slice_(self.k_smem_s,(None,None,None,0)) + v_s = cute.slice_(self.v_smem_s,(None,None,None,0)) + self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta + self.kv_tx_bytes = (cute.size_in_bytes(self.q_dtype, k_s) + + cute.size_in_bytes(self.q_dtype, v_s)) * cta + + @cute.jit + def __call__(self, q, k, v, c, stream, lse=None, swa_len=None, sink_bias=None, row_sums=None): + self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype + self.a_major = LayoutEnum.from_tensor(q).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(k).mma_major_mode() + v_fmha = cute.make_tensor( + v.iterator, + cute.make_layout( + (self.pv_n_tile, self.s_k, 1), + stride=(1, self.pv_n_tile, self.pv_n_tile * self.s_k), + ), + ) + self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode() + qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, tcgen05.CtaGroup.ONE, (128,128), tcgen05.OperandSource.SMEM) + pv_a_major = self.a_major if self.use_smem_p else cute.nvgpu.OperandMajorMode.K + pv_source = tcgen05.OperandSource.SMEM if self.use_smem_p else tcgen05.OperandSource.TMEM + pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, pv_a_major, self.v_major, self.qk_acc_dtype, tcgen05.CtaGroup.ONE, (128,self.pv_n_tile), pv_source) + self._setup(qk_mma, pv_mma) + q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0)) + tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape) + tma_k,mK = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,qk_mma.thr_id),k,k_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape) + tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape) + epi_s = cute.select(self.c_smem_s,mode=[0,1]) + tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile) + if const_expr(lse is None): + lse = cute.make_tensor(c.iterator, cute.make_layout((1,), stride=(0,))) + if const_expr(swa_len is None): + swa_len = Int32(2147483647) + else: + swa_len = Int32(swa_len) + if const_expr(sink_bias is None): + sink_bias = cute.make_tensor(lse.iterator, cute.make_layout((1,), stride=(0,))) + if const_expr(row_sums is None): + row_sums = cute.make_tensor(lse.iterator, lse.layout) + + self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse,swa_len,sink_bias,row_sums).launch(grid=(1,1,self.batch_size),block=[self.threads_per_cta,1,1],stream=stream) + + # ... rest of kernel to be implemented