Revert "D1.5: WIP SMEM accumulator — framework in place, accumulation logic TODO"
This reverts commit 72d88af400.
This commit is contained in:
@@ -410,29 +410,25 @@ class FmhaKernel:
|
||||
scale_log2 = Float32(self.scale_softmax_log2)
|
||||
|
||||
# ============================================================
|
||||
# D1.5: SMEM ACCUMULATOR for multi-KV-tile O rescale
|
||||
# D1.5: O RESCALE — SMEM ACCUMULATOR APPROACH
|
||||
# =================================================
|
||||
# TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken:
|
||||
# even NO-OP round-trip corrupts data (ratio = -11 billion).
|
||||
# Instead, we use one-way TMEM→REGS→SMEM after each PV,
|
||||
# accumulate in SMEM with acc_scale multiplication, and
|
||||
# TMA store SMEM→GMEM after all kt iterations.
|
||||
#
|
||||
# For n_kv_tiles=1 (s_k=128), the existing epilogue_tma_store
|
||||
# path works perfectly (cos=0.999998). The SMEM accumulator
|
||||
# is only needed for n_kv_tiles > 1.
|
||||
# ============================================================
|
||||
# After each PV[kt], move O from TMEM to SMEM via one-way epilogue.
|
||||
# Accumulate in SMEM with acc_scale multiplication.
|
||||
# TMEM round-trip is fundamentally broken — one-way only.
|
||||
# ============================================================
|
||||
if const_expr(self.n_kv_tiles > 1):
|
||||
# Build one-way TMEM→REGS→SMEM epilogue pipeline
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
tCtO_transformed = transform_partitioned_tensor_layout(tCtO_base)
|
||||
tCgC_transformed = transform_partitioned_tensor_layout(tCgC)
|
||||
|
||||
# TMEM→REGS (paired atoms from epilogue infrastructure)
|
||||
tiled_copy_t2r, tTR_tO_base, tTR_rO = epilogue_tmem_copy_and_partition(
|
||||
self, sfw_idx, tCtO_transformed, tCgC_transformed,
|
||||
epi_tile, self.use_2cta_instrs,
|
||||
)
|
||||
# REGS→SMEM (paired atoms)
|
||||
tTR_rC = cute.make_rmem_tensor(tTR_rO.shape, self.c_dtype)
|
||||
tiled_copy_r2s, tRS_rC, tRS_sC = epilogue_smem_copy_and_partition(
|
||||
self, tiled_copy_t2r, tTR_rC, sfw_idx, sC,
|
||||
)
|
||||
# NOTE: The code below is the BROKEN TMEM round-trip approach.
|
||||
# It's kept as reference but should NOT be used.
|
||||
# The SMEM accumulator implementation is TODO.
|
||||
|
||||
# prev_acc_scale: unused, kept for clarity. acc_scale at kt is used
|
||||
# to rescale O from kt=0..kt-1 before PV[kt].
|
||||
prev_acc_scale = Float32(0.0)
|
||||
|
||||
for kt in range(self.n_kv_tiles):
|
||||
@@ -533,54 +529,15 @@ class FmhaKernel:
|
||||
k2 = k_coord // 64
|
||||
_sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0]
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
# D1.5: SMEM accumulator epilogue for kt > 0.
|
||||
# After signaling P ready, wait for PV[kt] to complete,
|
||||
# then move O from TMEM to SMEM with acc_scale accumulation.
|
||||
# TMEM round-trip is broken — one-way only (TMEM→REGS→SMEM).
|
||||
if const_expr(self.n_kv_tiles > 1):
|
||||
si_handle.release()
|
||||
softmax_done_bar.arrive() # Signal MMA: P[kt] ready
|
||||
# D1.5: O rescale for kt > 0 — NOT YET IMPLEMENTED.
|
||||
# TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken:
|
||||
# even NO-OP round-trip corrupts O accumulator data.
|
||||
# Production path for multi-KV-tile: Python KV merge (cos 0.999998).
|
||||
# Future: SMEM accumulator approach (one-way TMEM→REGS→SMEM per kt).
|
||||
# n_kv_tiles=1 is the only supported path for in-kernel processing.
|
||||
|
||||
pv_done_bar.arrive_and_wait() # Wait for PV[kt] to complete
|
||||
|
||||
# One-way TMEM→REGS: load O_kt from TMEM
|
||||
tTR_tO_mn = tTR_tO_base[(None, None, None, 0)]
|
||||
cute.copy(tiled_copy_t2r, tTR_tO_mn, tTR_rO)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
|
||||
# In REGS: multiply by acc_scale (rescale previous accumulation)
|
||||
# acc_scale for kt=0 is 0.0 (old_row_max = -inf),
|
||||
# so O_acc = 0 * O_acc + O_0 = O_0 (first iteration).
|
||||
# For kt>0: O_acc = acc_scale * O_acc + O_kt.
|
||||
# But we can't load the previous SMEM accumulator into the same
|
||||
# register buffer. We need to go: TMEM→REGS, then SMEM↔REGS.
|
||||
#
|
||||
# Simpler: store O_kt to SMEM (BF16), then do a separate
|
||||
# SMEM-level accumulation pass. But this loses FP32 precision.
|
||||
#
|
||||
# Correct: convert FP32 O_kt to BF16, store to sC via r2s copy.
|
||||
# Then in a separate loop, load BF16 from sC, multiply by
|
||||
# acc_scale, add to FP32 SMEM accumulator. But this requires
|
||||
# an extra SMEM buffer.
|
||||
#
|
||||
# For NOW: use the MoE pattern — convert to BF16 in registers,
|
||||
# store to SMEM, then TMA to GMEM. The accumulation happens
|
||||
# in a separate kernel or in Python (KV merge).
|
||||
# This is a stepping stone to full SMEM accumulation.
|
||||
|
||||
# Convert FP32 O to BF16 in registers
|
||||
for k in cutlass.range(cute.size(tTR_rO), vectorize=True):
|
||||
tTR_rC[k] = tTR_rO[k].to(self.c_dtype)
|
||||
|
||||
# REGS→SMEM: store BF16 O_kt to sC
|
||||
# For kt>0, we need to ADD to previous sC. But epilogue_smem_copy
|
||||
# does a store, not an add. We need SMEM read-modify-write.
|
||||
# SKIP for now — just do one-way epilogue per kt to separate
|
||||
# GMEM buffers, then merge in Python.
|
||||
pass # TODO: SMEM accumulation
|
||||
else:
|
||||
si_handle.release()
|
||||
softmax_done_bar.arrive()
|
||||
si_handle.release()
|
||||
softmax_done_bar.arrive()
|
||||
|
||||
# Wait for MMA's PV[N-1] to commit before reading O.
|
||||
final_o_bar.arrive_and_wait()
|
||||
|
||||
Reference in New Issue
Block a user