Revert "D1.5: WIP SMEM accumulator — framework in place, accumulation logic TODO"

This reverts commit 72d88af400.
This commit is contained in:
2026-05-27 02:17:26 +00:00
parent 72d88af400
commit 81acf1593c

View File

@@ -410,29 +410,25 @@ class FmhaKernel:
scale_log2 = Float32(self.scale_softmax_log2)
# ============================================================
# D1.5: SMEM ACCUMULATOR for multi-KV-tile O rescale
# D1.5: O RESCALE — SMEM ACCUMULATOR APPROACH
# =================================================
# TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken:
# even NO-OP round-trip corrupts data (ratio = -11 billion).
# Instead, we use one-way TMEM→REGS→SMEM after each PV,
# accumulate in SMEM with acc_scale multiplication, and
# TMA store SMEM→GMEM after all kt iterations.
#
# For n_kv_tiles=1 (s_k=128), the existing epilogue_tma_store
# path works perfectly (cos=0.999998). The SMEM accumulator
# is only needed for n_kv_tiles > 1.
# ============================================================
# After each PV[kt], move O from TMEM to SMEM via one-way epilogue.
# Accumulate in SMEM with acc_scale multiplication.
# TMEM round-trip is fundamentally broken — one-way only.
# ============================================================
if const_expr(self.n_kv_tiles > 1):
# Build one-way TMEM→REGS→SMEM epilogue pipeline
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
tCtO_transformed = transform_partitioned_tensor_layout(tCtO_base)
tCgC_transformed = transform_partitioned_tensor_layout(tCgC)
# TMEM→REGS (paired atoms from epilogue infrastructure)
tiled_copy_t2r, tTR_tO_base, tTR_rO = epilogue_tmem_copy_and_partition(
self, sfw_idx, tCtO_transformed, tCgC_transformed,
epi_tile, self.use_2cta_instrs,
)
# REGS→SMEM (paired atoms)
tTR_rC = cute.make_rmem_tensor(tTR_rO.shape, self.c_dtype)
tiled_copy_r2s, tRS_rC, tRS_sC = epilogue_smem_copy_and_partition(
self, tiled_copy_t2r, tTR_rC, sfw_idx, sC,
)
# NOTE: The code below is the BROKEN TMEM round-trip approach.
# It's kept as reference but should NOT be used.
# The SMEM accumulator implementation is TODO.
# prev_acc_scale: unused, kept for clarity. acc_scale at kt is used
# to rescale O from kt=0..kt-1 before PV[kt].
prev_acc_scale = Float32(0.0)
for kt in range(self.n_kv_tiles):
@@ -533,54 +529,15 @@ class FmhaKernel:
k2 = k_coord // 64
_sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0]
cute.arch.fence_proxy("async.shared", space="cta")
# D1.5: SMEM accumulator epilogue for kt > 0.
# After signaling P ready, wait for PV[kt] to complete,
# then move O from TMEM to SMEM with acc_scale accumulation.
# TMEM round-trip is broken — one-way only (TMEM→REGS→SMEM).
if const_expr(self.n_kv_tiles > 1):
si_handle.release()
softmax_done_bar.arrive() # Signal MMA: P[kt] ready
# D1.5: O rescale for kt > 0 — NOT YET IMPLEMENTED.
# TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken:
# even NO-OP round-trip corrupts O accumulator data.
# Production path for multi-KV-tile: Python KV merge (cos 0.999998).
# Future: SMEM accumulator approach (one-way TMEM→REGS→SMEM per kt).
# n_kv_tiles=1 is the only supported path for in-kernel processing.
pv_done_bar.arrive_and_wait() # Wait for PV[kt] to complete
# One-way TMEM→REGS: load O_kt from TMEM
tTR_tO_mn = tTR_tO_base[(None, None, None, 0)]
cute.copy(tiled_copy_t2r, tTR_tO_mn, tTR_rO)
cute.arch.fence_view_async_tmem_load()
# In REGS: multiply by acc_scale (rescale previous accumulation)
# acc_scale for kt=0 is 0.0 (old_row_max = -inf),
# so O_acc = 0 * O_acc + O_0 = O_0 (first iteration).
# For kt>0: O_acc = acc_scale * O_acc + O_kt.
# But we can't load the previous SMEM accumulator into the same
# register buffer. We need to go: TMEM→REGS, then SMEM↔REGS.
#
# Simpler: store O_kt to SMEM (BF16), then do a separate
# SMEM-level accumulation pass. But this loses FP32 precision.
#
# Correct: convert FP32 O_kt to BF16, store to sC via r2s copy.
# Then in a separate loop, load BF16 from sC, multiply by
# acc_scale, add to FP32 SMEM accumulator. But this requires
# an extra SMEM buffer.
#
# For NOW: follow the MoE pattern — convert to BF16 in registers,
# store to SMEM, then TMA to GMEM — with the accumulation happening
# in a separate kernel or in Python (KV merge). Only the BF16
# conversion is implemented below; the SMEM store is still a TODO.
# This is a stepping stone to full SMEM accumulation.
# Convert FP32 O to BF16 in registers
for k in cutlass.range(cute.size(tTR_rO), vectorize=True):
tTR_rC[k] = tTR_rO[k].to(self.c_dtype)
# REGS→SMEM: store BF16 O_kt to sC
# For kt>0, we need to ADD to the previous contents of sC, but the r2s
# copy built by epilogue_smem_copy_and_partition does a store, not an
# add, so this needs an SMEM read-modify-write.
# SKIP for now — just do one-way epilogue per kt to separate
# GMEM buffers, then merge in Python.
pass # TODO: SMEM accumulation
else:
si_handle.release()
softmax_done_bar.arrive()
si_handle.release()
softmax_done_bar.arrive()
# Wait for MMA's PV[N-1] to commit before reading O.
final_o_bar.arrive_and_wait()