From 81acf1593cbde3bba2f3349c9cd668d282a1f3b2 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 27 May 2026 02:17:26 +0000 Subject: [PATCH] =?UTF-8?q?Revert=20"D1.5:=20WIP=20SMEM=20accumulator=20?= =?UTF-8?q?=E2=80=94=20framework=20in=20place,=20accumulation=20logic=20TO?= =?UTF-8?q?DO"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 72d88af400e1a2715ac824b1f070446a7c312b95. --- dsv4/kernels/attention/fmha.py | 91 +++++++++------------------------- 1 file changed, 24 insertions(+), 67 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 999a17fb..a0a4e30e 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -410,29 +410,25 @@ class FmhaKernel: scale_log2 = Float32(self.scale_softmax_log2) # ============================================================ - # D1.5: SMEM ACCUMULATOR for multi-KV-tile O rescale + # D1.5: O RESCALE — SMEM ACCUMULATOR APPROACH + # ================================================= + # TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken: + # even NO-OP round-trip corrupts data (ratio = -11 billion). + # Instead, we use one-way TMEM→REGS→SMEM after each PV, + # accumulate in SMEM with acc_scale multiplication, and + # TMA store SMEM→GMEM after all kt iterations. + # + # For n_kv_tiles=1 (s_k=128), the existing epilogue_tma_store + # path works perfectly (cos=0.999998). The SMEM accumulator + # is only needed for n_kv_tiles > 1. # ============================================================ - # After each PV[kt], move O from TMEM to SMEM via one-way epilogue. - # Accumulate in SMEM with acc_scale multiplication. - # TMEM round-trip is fundamentally broken — one-way only. - # ============================================================ - if const_expr(self.n_kv_tiles > 1): - # Build one-way TMEM→REGS→SMEM epilogue pipeline - tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) - tCtO_transformed = transform_partitioned_tensor_layout(tCtO_base) - tCgC_transformed = transform_partitioned_tensor_layout(tCgC) - # TMEM→REGS (paired atoms from epilogue infrastructure) - tiled_copy_t2r, tTR_tO_base, tTR_rO = epilogue_tmem_copy_and_partition( - self, sfw_idx, tCtO_transformed, tCgC_transformed, - epi_tile, self.use_2cta_instrs, - ) - # REGS→SMEM (paired atoms) - tTR_rC = cute.make_rmem_tensor(tTR_rO.shape, self.c_dtype) - tiled_copy_r2s, tRS_rC, tRS_sC = epilogue_smem_copy_and_partition( - self, tiled_copy_t2r, tTR_rC, sfw_idx, sC, - ) + # NOTE: The code below is the BROKEN TMEM round-trip approach. + # It's kept as reference but should NOT be used. + # The SMEM accumulator implementation is TODO. + # prev_acc_scale: unused, kept for clarity. acc_scale at kt is used + # to rescale O from kt=0..kt-1 before PV[kt]. prev_acc_scale = Float32(0.0) for kt in range(self.n_kv_tiles): @@ -533,54 +529,15 @@ class FmhaKernel: k2 = k_coord // 64 _sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0] cute.arch.fence_proxy("async.shared", space="cta") - # D1.5: SMEM accumulator epilogue for kt > 0. - # After signaling P ready, wait for PV[kt] to complete, - # then move O from TMEM to SMEM with acc_scale accumulation. - # TMEM round-trip is broken — one-way only (TMEM→REGS→SMEM). - if const_expr(self.n_kv_tiles > 1): - si_handle.release() - softmax_done_bar.arrive() # Signal MMA: P[kt] ready + # D1.5: O rescale for kt > 0 — NOT YET IMPLEMENTED. + # TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken: + # even NO-OP round-trip corrupts O accumulator data. + # Production path for multi-KV-tile: Python KV merge (cos 0.999998). + # Future: SMEM accumulator approach (one-way TMEM→REGS→SMEM per kt). + # n_kv_tiles=1 is the only supported path for in-kernel processing. - pv_done_bar.arrive_and_wait() # Wait for PV[kt] to complete - - # One-way TMEM→REGS: load O_kt from TMEM - tTR_tO_mn = tTR_tO_base[(None, None, None, 0)] - cute.copy(tiled_copy_t2r, tTR_tO_mn, tTR_rO) - cute.arch.fence_view_async_tmem_load() - - # In REGS: multiply by acc_scale (rescale previous accumulation) - # acc_scale for kt=0 is 0.0 (old_row_max = -inf), - # so O_acc = 0 * O_acc + O_0 = O_0 (first iteration). - # For kt>0: O_acc = acc_scale * O_acc + O_kt. - # But we can't load the previous SMEM accumulator into the same - # register buffer. We need to go: TMEM→REGS, then SMEM↔REGS. - # - # Simpler: store O_kt to SMEM (BF16), then do a separate - # SMEM-level accumulation pass. But this loses FP32 precision. - # - # Correct: convert FP32 O_kt to BF16, store to sC via r2s copy. - # Then in a separate loop, load BF16 from sC, multiply by - # acc_scale, add to FP32 SMEM accumulator. But this requires - # an extra SMEM buffer. - # - # For NOW: use the MoE pattern — convert to BF16 in registers, - # store to SMEM, then TMA to GMEM. The accumulation happens - # in a separate kernel or in Python (KV merge). - # This is a stepping stone to full SMEM accumulation. - - # Convert FP32 O to BF16 in registers - for k in cutlass.range(cute.size(tTR_rO), vectorize=True): - tTR_rC[k] = tTR_rO[k].to(self.c_dtype) - - # REGS→SMEM: store BF16 O_kt to sC - # For kt>0, we need to ADD to previous sC. But epilogue_smem_copy - # does a store, not an add. We need SMEM read-modify-write. - # SKIP for now — just do one-way epilogue per kt to separate - # GMEM buffers, then merge in Python. - pass # TODO: SMEM accumulation - else: - si_handle.release() - softmax_done_bar.arrive() + si_handle.release() + softmax_done_bar.arrive() # Wait for MMA's PV[N-1] to commit before reading O. final_o_bar.arrive_and_wait()