D1.5: Revert broken paired-atom O rescale — TMEM round-trip fundamentally broken
Ld32x32bOp and St32x32bOp have different column mappings at the hardware level. No layout transformation can fix this — the atoms themselves map TMEM columns differently. The MoE correction epilogue avoids the problem by doing a ONE-WAY trip (TMEM→REGS→SMEM→GMEM, never writes back to TMEM). FMHA needs O in TMEM for PV accumulation between kt iterations, so one-way doesn't help. Production path for multi-KV-tile: Python KV merge (already verified, cos 0.999998 for s_k up to 1024). Run kernel per 128-token segment. Future: restructure PV to accumulate into REGS/SMEM instead of TMEM, enabling the one-way correction epilogue pattern.
This commit is contained in:
@@ -11,10 +11,9 @@ from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
from cutlass.utils.blackwell_helpers import get_smem_store_op
|
||||
from cutlass.utils.gemm.sm100 import (
|
||||
transform_partitioned_tensor_layout,
|
||||
epilogue_tmem_copy_and_partition,
|
||||
)
|
||||
# TMEM round-trip is fundamentally broken (Ld32x32bOp/St32x32bOp column mapping mismatch).
|
||||
# The one-way correction epilogue pattern (from cutlass.utils.gemm.sm100) requires
|
||||
# restructuring PV to not use TMEM accumulator. See D1.5 notes in STAGE_D.md.
|
||||
import cuda.bindings.driver as cuda
|
||||
import cutlass.torch as ct
|
||||
import math
|
||||
@@ -398,24 +397,22 @@ class FmhaKernel:
|
||||
scale_log2 = Float32(self.scale_softmax_log2)
|
||||
|
||||
# ============================================================
|
||||
# O RESCALE PAIRED ATOMS (D1.5 fix, multi-KV-tile only)
|
||||
# D1.5: MULTI-KV-TILE O RESCALE — NOT SUPPORTED IN-KERNEL
|
||||
# ============================================================
|
||||
# Replace broken hand-constructed Ld32x32bOp/St32x32bOp round-trip
|
||||
# with paired atoms from epilogue_tmem_copy_and_partition.
|
||||
# The paired atoms share addressing, so the TMEM→REGS→modify→TMEM
|
||||
# cycle is lossless (unlike independently constructed atoms).
|
||||
# Only needed when n_kv_tiles > 1 (multi-KV-tile O rescale).
|
||||
# TMEM round-trip (load O, modify, store back) is FUNDAMENTALLY
|
||||
# broken: Ld32x32bOp and St32x32bOp have different column mappings
|
||||
# at the hardware level. The MoE correction epilogue avoids this
|
||||
# by doing a ONE-WAY trip (TMEM->REGS->SMEM->GMEM), but FMHA needs
|
||||
# to keep O in TMEM for PV accumulation between kt iterations.
|
||||
#
|
||||
# Production path for multi-KV-tile: Python KV merge.
|
||||
# Run kernel per 128-token segment (s_k=128), merge externally:
|
||||
# O = sum_i [exp(lse_i) * O_i_norm] / sum_i [exp(lse_i)]
|
||||
# Verified cos 0.999998 for s_k up to 1024.
|
||||
#
|
||||
# Future: restructure PV to accumulate into REGS/SMEM instead
|
||||
# of TMEM, enabling the one-way correction epilogue pattern.
|
||||
# ============================================================
|
||||
if const_expr(self.n_kv_tiles > 1):
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
tCtO_transformed = transform_partitioned_tensor_layout(tCtO_base)
|
||||
tCgC_transformed = transform_partitioned_tensor_layout(tCgC)
|
||||
tiled_copy_t2r, tTR_tO_base, tTR_rO = epilogue_tmem_copy_and_partition(
|
||||
self, sfw_idx, tCtO_transformed, tCgC_transformed,
|
||||
epi_tile, self.use_2cta_instrs,
|
||||
)
|
||||
tTR_tO_grouped = cute.group_modes(tTR_tO_base, 3, cute.rank(tTR_tO_base))
|
||||
subtile_cnt = cute.size(tTR_tO_grouped.shape, mode=[3])
|
||||
|
||||
for kt in range(self.n_kv_tiles):
|
||||
si_handle = s_cons.wait_and_advance()
|
||||
@@ -515,23 +512,9 @@ class FmhaKernel:
|
||||
k2 = k_coord // 64
|
||||
_sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0]
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
# O rescale for kt > 0 using paired atoms (D1.5 fix).
|
||||
# TMEM→REGS (paired load), multiply by acc_scale,
|
||||
# REGS→TMEM (paired store via retile_to_S).
|
||||
# The paired atom's addressing is consistent for load and store,
|
||||
# so this does NOT suffer from the layout mismatch that broke the
|
||||
# hand-constructed Ld32x32bOp/St32x32bOp round-trip.
|
||||
if const_expr(self.n_kv_tiles > 1):
|
||||
if kt > 0:
|
||||
for subtile_idx in range(subtile_cnt):
|
||||
tTR_tO_mn = tTR_tO_grouped[(None, None, None, subtile_idx)]
|
||||
cute.copy(tiled_copy_t2r, tTR_tO_mn, tTR_rO)
|
||||
# Modify in registers
|
||||
for k in cutlass.range(cute.size(tTR_rO), vectorize=True):
|
||||
tTR_rO[k] = tTR_rO[k] * acc_scale
|
||||
# Store back to TMEM via paired atom's store direction
|
||||
cute.copy(tiled_copy_t2r.retile_to_S(), tTR_rO, tTR_tO_mn)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
# D1.5: O rescale for kt > 0 is NOT supported in-kernel.
|
||||
# Multi-KV-tile attention uses Python KV merge instead.
|
||||
# n_kv_tiles=1 is the only tested/supported path.
|
||||
|
||||
si_handle.release()
|
||||
softmax_done_bar.arrive()
|
||||
|
||||
Reference in New Issue
Block a user