diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 4475fae7..d82b030a 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -11,10 +11,9 @@ from cutlass import Float32, BFloat16, Int32, Boolean, const_expr from cutlass.utils import LayoutEnum from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset from cutlass.utils.blackwell_helpers import get_smem_store_op -from cutlass.utils.gemm.sm100 import ( - transform_partitioned_tensor_layout, - epilogue_tmem_copy_and_partition, -) +# TMEM round-trip is fundamentally broken (Ld32x32bOp/St32x32bOp column mapping mismatch). +# The one-way correction epilogue pattern (from cutlass.utils.gemm.sm100) requires +# restructuring PV to not use TMEM accumulator. See D1.5 notes in STAGE_D.md. import cuda.bindings.driver as cuda import cutlass.torch as ct import math @@ -398,24 +397,22 @@ class FmhaKernel: scale_log2 = Float32(self.scale_softmax_log2) # ============================================================ - # O RESCALE PAIRED ATOMS (D1.5 fix, multi-KV-tile only) + # D1.5: MULTI-KV-TILE O RESCALE — NOT SUPPORTED IN-KERNEL # ============================================================ - # Replace broken hand-constructed Ld32x32bOp/St32x32bOp round-trip - # with paired atoms from epilogue_tmem_copy_and_partition. - # The paired atoms share addressing, so the TMEM→REGS→modify→TMEM - # cycle is lossless (unlike independently constructed atoms). - # Only needed when n_kv_tiles > 1 (multi-KV-tile O rescale). + # TMEM round-trip (load O, modify, store back) is FUNDAMENTALLY + # broken: Ld32x32bOp and St32x32bOp have different column mappings + # at the hardware level. The MoE correction epilogue avoids this + # by doing a ONE-WAY trip (TMEM->REGS->SMEM->GMEM), but FMHA needs + # to keep O in TMEM for PV accumulation between kt iterations. + # + # Production path for multi-KV-tile: Python KV merge. + # Run kernel per 128-token segment (s_k=128), merge externally: + # O = sum_i [exp(lse_i) * O_i_norm] / sum_i [exp(lse_i)] + # Verified cos 0.999998 for s_k up to 1024. + # + # Future: restructure PV to accumulate into REGS/SMEM instead + # of TMEM, enabling the one-way correction epilogue pattern. # ============================================================ - if const_expr(self.n_kv_tiles > 1): - tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) - tCtO_transformed = transform_partitioned_tensor_layout(tCtO_base) - tCgC_transformed = transform_partitioned_tensor_layout(tCgC) - tiled_copy_t2r, tTR_tO_base, tTR_rO = epilogue_tmem_copy_and_partition( - self, sfw_idx, tCtO_transformed, tCgC_transformed, - epi_tile, self.use_2cta_instrs, - ) - tTR_tO_grouped = cute.group_modes(tTR_tO_base, 3, cute.rank(tTR_tO_base)) - subtile_cnt = cute.size(tTR_tO_grouped.shape, mode=[3]) for kt in range(self.n_kv_tiles): si_handle = s_cons.wait_and_advance() @@ -515,23 +512,9 @@ class FmhaKernel: k2 = k_coord // 64 _sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0] cute.arch.fence_proxy("async.shared", space="cta") - # O rescale for kt > 0 using paired atoms (D1.5 fix). - # TMEM→REGS (paired load), multiply by acc_scale, - # REGS→TMEM (paired store via retile_to_S). - # The paired atom's addressing is consistent for load and store, - # so this does NOT suffer from the layout mismatch that broke the - # hand-constructed Ld32x32bOp/St32x32bOp round-trip. - if const_expr(self.n_kv_tiles > 1): - if kt > 0: - for subtile_idx in range(subtile_cnt): - tTR_tO_mn = tTR_tO_grouped[(None, None, None, subtile_idx)] - cute.copy(tiled_copy_t2r, tTR_tO_mn, tTR_rO) - # Modify in registers - for k in cutlass.range(cute.size(tTR_rO), vectorize=True): - tTR_rO[k] = tTR_rO[k] * acc_scale - # Store back to TMEM via paired atom's store direction - cute.copy(tiled_copy_t2r.retile_to_S(), tTR_rO, tTR_tO_mn) - cute.arch.fence_view_async_tmem_store() + # D1.5: O rescale for kt > 0 is NOT supported in-kernel. + # Multi-KV-tile attention uses Python KV merge instead. + # n_kv_tiles=1 is the only tested/supported path. si_handle.release() softmax_done_bar.arrive()