D1.5: Revert broken paired-atom O rescale — TMEM round-trip fundamentally broken

Ld32x32bOp and St32x32bOp have different column mappings at the hardware
level. No layout transformation can fix this — the atoms themselves map
TMEM columns differently.

The MoE correction epilogue avoids the problem by doing a ONE-WAY trip
(TMEM→REGS→SMEM→GMEM, never writes back to TMEM). FMHA needs O in TMEM
for PV accumulation between kt iterations, so one-way doesn't help.

Production path for multi-KV-tile: Python KV merge (already verified,
cos 0.999998 for s_k up to 1024). Run kernel per 128-token segment.

Future: restructure PV to accumulate into REGS/SMEM instead of TMEM,
enabling the one-way correction epilogue pattern.
This commit is contained in:
2026-05-26 19:50:11 +00:00
parent 40cbf0c223
commit ffb3e736bb

View File

@@ -11,10 +11,9 @@ from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
from cutlass.utils.blackwell_helpers import get_smem_store_op
from cutlass.utils.gemm.sm100 import (
transform_partitioned_tensor_layout,
epilogue_tmem_copy_and_partition,
)
# TMEM round-trip is fundamentally broken (Ld32x32bOp/St32x32bOp column mapping mismatch).
# The one-way correction epilogue pattern (from cutlass.utils.gemm.sm100) requires
# restructuring PV to not use TMEM accumulator. See D1.5 notes in STAGE_D.md.
import cuda.bindings.driver as cuda
import cutlass.torch as ct
import math
@@ -398,24 +397,22 @@ class FmhaKernel:
scale_log2 = Float32(self.scale_softmax_log2)
# ============================================================
# O RESCALE PAIRED ATOMS (D1.5 fix, multi-KV-tile only)
# D1.5: MULTI-KV-TILE O RESCALE — NOT SUPPORTED IN-KERNEL
# ============================================================
# Replace broken hand-constructed Ld32x32bOp/St32x32bOp round-trip
# with paired atoms from epilogue_tmem_copy_and_partition.
# The paired atoms share addressing, so the TMEM→REGS→modify→TMEM
# cycle is lossless (unlike independently constructed atoms).
# Only needed when n_kv_tiles > 1 (multi-KV-tile O rescale).
# TMEM round-trip (load O, modify, store back) is FUNDAMENTALLY
# broken: Ld32x32bOp and St32x32bOp have different column mappings
# at the hardware level. The MoE correction epilogue avoids this
# by doing a ONE-WAY trip (TMEM->REGS->SMEM->GMEM), but FMHA needs
# to keep O in TMEM for PV accumulation between kt iterations.
#
# Production path for multi-KV-tile: Python KV merge.
# Run kernel per 128-token segment (s_k=128), merge externally:
# O = sum_i [exp(lse_i) * O_i_norm] / sum_i [exp(lse_i)]
# Verified cos 0.999998 for s_k up to 1024.
#
# Future: restructure PV to accumulate into REGS/SMEM instead
# of TMEM, enabling the one-way correction epilogue pattern.
# ============================================================
if const_expr(self.n_kv_tiles > 1):
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
tCtO_transformed = transform_partitioned_tensor_layout(tCtO_base)
tCgC_transformed = transform_partitioned_tensor_layout(tCgC)
tiled_copy_t2r, tTR_tO_base, tTR_rO = epilogue_tmem_copy_and_partition(
self, sfw_idx, tCtO_transformed, tCgC_transformed,
epi_tile, self.use_2cta_instrs,
)
tTR_tO_grouped = cute.group_modes(tTR_tO_base, 3, cute.rank(tTR_tO_base))
subtile_cnt = cute.size(tTR_tO_grouped.shape, mode=[3])
for kt in range(self.n_kv_tiles):
si_handle = s_cons.wait_and_advance()
@@ -515,23 +512,9 @@ class FmhaKernel:
k2 = k_coord // 64
_sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0]
cute.arch.fence_proxy("async.shared", space="cta")
# O rescale for kt > 0 using paired atoms (D1.5 fix).
# TMEM→REGS (paired load), multiply by acc_scale,
# REGS→TMEM (paired store via retile_to_S).
# The paired atom's addressing is consistent for load and store,
# so this does NOT suffer from the layout mismatch that broke the
# hand-constructed Ld32x32bOp/St32x32bOp round-trip.
if const_expr(self.n_kv_tiles > 1):
if kt > 0:
for subtile_idx in range(subtile_cnt):
tTR_tO_mn = tTR_tO_grouped[(None, None, None, subtile_idx)]
cute.copy(tiled_copy_t2r, tTR_tO_mn, tTR_rO)
# Modify in registers
for k in cutlass.range(cute.size(tTR_rO), vectorize=True):
tTR_rO[k] = tTR_rO[k] * acc_scale
# Store back to TMEM via paired atom's store direction
cute.copy(tiled_copy_t2r.retile_to_S(), tTR_rO, tTR_tO_mn)
cute.arch.fence_view_async_tmem_store()
# D1.5: O rescale for kt > 0 is NOT supported in-kernel.
# Multi-KV-tile attention uses Python KV merge instead.
# n_kv_tiles=1 is the only tested/supported path.
si_handle.release()
softmax_done_bar.arrive()