D1.5: Revert broken TMEM round-trip O rescale, document as fundamentally broken

TMEM round-trip via Ld32x32bOp/St32x32bOp corrupts O accumulator data
even with CUTLASS correction_rescale pattern. All variants tested:
- Repetition(16) + composition (CUTLASS exact pattern) — BROKEN
- Repetition(32) + composition — BROKEN
- Repetition(16) raw layout (no composition) — BROKEN
Even NO-OP (multiply by 1.0) produces catastrophically wrong results.

Production path remains Python KV merge (cos 0.999998 for s_k up to 1024).
Next: SMEM accumulator approach (one-way TMEM→REGS→SMEM per kt).
This commit is contained in:
2026-05-26 20:55:16 +00:00
parent 42c5793add
commit afb93eae22
2 changed files with 176 additions and 91 deletions

View File

@@ -21,7 +21,7 @@ import math
class FmhaKernel:
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1, apply_swa_mask=False, is_causal=False, n_comp=None, apply_sink_bias=False, debug_noop_rescale=False):
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1, apply_swa_mask=False, is_causal=False, n_comp=None, apply_sink_bias=False):
# D5c: n_comp = compressed KV length. Sink bias (attn_sink) applies to
# positions >= n_comp. D3/D4 masks also only apply to SWA region.
# When n_comp is None or 0, no offset (backward compatible).
@@ -58,8 +58,6 @@ class FmhaKernel:
self.kv_stage = 1 if head_dim > 128 else 2 # Reduce SMEM at large hd
self.q_stage = 1
self.num_c_stage = 1 if head_dim > 256 else 2 # Reduce SMEM at hd=512
self.debug_noop_rescale = debug_noop_rescale # D1.5 debug: force acc_scale=1.0 in O rescale
self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim)
self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e)
@@ -189,9 +187,8 @@ class FmhaKernel:
s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants()
softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id))
final_o_bar = pipeline.NamedBarrier(barrier_id=4, num_threads=32 + 32*len(self.epilogue_warp_id))
# D1.5: barrier for PV completion signal (MMA→softmax warps)
# MMA warp arrives after PV[kt] completes; softmax warps wait before O rescale.
pv_done_bar = pipeline.NamedBarrier(barrier_id=5, num_threads=32 + 32*len(self.epilogue_warp_id))
# D1.5: pv_done_bar for O rescale (currently unused — TMEM round-trip broken)
# pv_done_bar = pipeline.NamedBarrier(barrier_id=5, num_threads=32 + 32*len(self.epilogue_warp_id))
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
@@ -319,8 +316,7 @@ class FmhaKernel:
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store()
kvh_v.release()
if const_expr(self.n_kv_tiles > 1):
pv_done_bar.arrive() # Signal softmax warps: PV done, O is ready for rescale
# pv_done_bar.arrive() # D1.5: unused — TMEM round-trip broken
final_o_bar.arrive()
else:
# Original pipeline path (hd≤256)
@@ -338,8 +334,6 @@ class FmhaKernel:
cute.arch.fence_view_async_tmem_store()
sh.commit()
softmax_done_bar.arrive_and_wait()
if const_expr(self.n_kv_tiles > 1):
cute.arch.fence_view_async_tmem_load() # Ensure rescaled O visible before PV[kt]
pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
if not self.use_smem_p:
for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True):
@@ -351,8 +345,7 @@ class FmhaKernel:
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store()
kvh.release()
if const_expr(self.n_kv_tiles > 1):
pv_done_bar.arrive() # Signal softmax warps: PV done, O ready for rescale
# pv_done_bar.arrive() # D1.5: unused — TMEM round-trip broken
acc_pipe.producer_commit(acc_st); acc_st.advance()
final_o_bar.arrive()
acc_pipe.producer_tail(acc_st)
@@ -409,53 +402,22 @@ class FmhaKernel:
scale_log2 = Float32(self.scale_softmax_log2)
# ============================================================
# D1.5: O RESCALE ATOMS (CUTLASS correction_rescale pattern)
# D1.5: O RESCALE — SMEM ACCUMULATOR APPROACH
# =================================================
# TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken:
# even NO-OP round-trip corrupts data (ratio = -11 billion).
# Instead, we use one-way TMEM→REGS→SMEM after each PV,
# accumulate in SMEM with acc_scale multiplication, and
# TMA store SMEM→GMEM after all kt iterations.
#
# For n_kv_tiles=1 (s_k=128), the existing epilogue_tma_store
# path works perfectly (cos=0.999998). The SMEM accumulator
# is only needed for n_kv_tiles > 1.
# ============================================================
# Pattern: both load and store atoms built from the SAME tOtO_i
# (composition-tiled from tOtO0), same Repetition(corr_tile_size).
# This is the exact pattern from CUTLASS reference fmha.py line 2123.
# The key insight: using composition() to re-tile tOtO into (128, corr_tile_size)
# sub-tiles, and building BOTH copies from the SAME tensor, ensures the
# column mappings agree on round-trip.
# ============================================================
corr_tile_size = 16 # Must be power of 2, divides head_dim
# Try both composition and raw layout
use_comp = True
if const_expr(use_comp):
tOtO_i_layout = cute.composition(
tOtO0.layout, cute.make_layout((128, corr_tile_size))
)
else:
tOtO_i_layout = tOtO0.layout
tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout)
# Coordinate tensor for O (needed for partition_D of load)
cO = cute.make_identity_tensor((128, self.head_dim))
tOcO = pv_thr.partition_C(cO)
if const_expr(use_comp):
tOcO_i_layout = cute.composition(
tOcO.layout, cute.make_layout((128, corr_tile_size))
)
else:
tOcO_i_layout = tOcO.layout
tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout)
tmem_load_o_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)),
self.qk_acc_dtype,
)
tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i)
thr_tmem_load_o = tiled_tmem_load_o.get_slice(sfw_idx)
tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i)
tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i)
tmem_store_o_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)),
self.qk_acc_dtype,
)
tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i)
thr_tmem_store_o = tiled_tmem_store_o.get_slice(sfw_idx)
tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i)
# NOTE: The code below is the BROKEN TMEM round-trip approach.
# It's kept as reference but should NOT be used.
# The SMEM accumulator implementation is TODO.
# prev_acc_scale: unused, kept for clarity. acc_scale at kt is used
# to rescale O from kt=0..kt-1 before PV[kt].
@@ -559,40 +521,12 @@ class FmhaKernel:
k2 = k_coord // 64
_sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0]
cute.arch.fence_proxy("async.shared", space="cta")
# D1.5: O rescale for kt > 0 — CUTLASS correction_rescale pattern.
# After computing acc_scale for this iteration, rescale the existing O
# in TMEM before the next PV GEMM adds to it.
# Must wait for PV[kt-1] to complete (MMA signals pv_done_bar).
if const_expr(self.n_kv_tiles > 1):
if kt > 0:
pv_done_bar.arrive_and_wait() # Wait for PV[kt-1]
# Rescale O: load, multiply by acc_scale, store back to TMEM.
# CUTLASS pattern: both copies use same tOtO_i (composition-tiled).
rescale_factor = acc_scale
if const_expr(self.debug_noop_rescale):
rescale_factor = Float32(1.0)
n_slices = self.head_dim // corr_tile_size
tTMrO = cute.make_rmem_tensor(
(tTMEM_LOADcO.shape, n_slices), self.qk_acc_dtype
)
for i in range(n_slices):
tTMrO_i_ = tTMrO[None, i]
tTMrO_i_layout = cute.composition(
tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0])
)
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
tTMEM_LOADtO_i = cute.make_tensor(
tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout
)
tTMEM_STOREtO_i = cute.make_tensor(
tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout
)
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
cute.arch.fence_view_async_tmem_load()
for k in cutlass.range(cute.size(tTMrO_i), vectorize=True):
tTMrO_i[k] = tTMrO_i[k] * rescale_factor
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
cute.arch.fence_view_async_tmem_store()
# D1.5: O rescale for kt > 0 — NOT YET IMPLEMENTED.
# TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken:
# even NO-OP round-trip corrupts O accumulator data.
# Production path for multi-KV-tile: Python KV merge (cos 0.999998).
# Future: SMEM accumulator approach (one-way TMEM→REGS→SMEM per kt).
# n_kv_tiles=1 is the only supported path for in-kernel processing.
si_handle.release()
softmax_done_bar.arrive()

View File

@@ -0,0 +1,151 @@
"""FMHA kernel: QK -> online softmax -> PV (CuTeDSL, Blackwell SM100).
SMEM accumulator approach for multi-KV-tile O rescale.
Instead of TMEM round-trip (which corrupts data), we move O from TMEM
to SMEM after each PV GEMM via one-way epilogue, and accumulate in SMEM.
This avoids the D1.5 TMEM round-trip bug entirely.
Architecture:
- 6-warp specialization: 4 softmax+epilogue, 1 MMA, 1 TMA
- After PV[kt]: one-way TMEM→REGS→SMEM with acc_scale multiplication
- SMEM accumulator persists across kt iterations
- Final TMA store: SMEM→GMEM
Per-kt flow:
1. Softmax warps: compute P[kt], acc_scale[kt]
2. Signal softmax_done_bar
3. MMA warp: PV[kt] GEMM (ACCUMULATE=False, fresh TMEM)
4. Signal pv_done_bar
5. Softmax/epilogue warps: TMEM→REGS, acc_scale*O_acc + O_kt, REGS→SMEM
6. Repeat for next kt
7. After all kt: SMEM→GMEM via TMA
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
from cutlass.utils.blackwell_helpers import get_smem_store_op
from cutlass.utils.gemm.sm100 import (
transform_partitioned_tensor_layout,
epilogue_tmem_copy_and_partition,
epilogue_smem_copy_and_partition,
)
import cuda.bindings.driver as cuda
import cutlass.torch as ct
import math
class FmhaKernel:
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True,
num_query_heads=1, batch_size=1, apply_swa_mask=False, is_causal=False,
n_comp=None, apply_sink_bias=False):
self.n_comp = n_comp if n_comp is not None else 0
self.apply_sink_bias = apply_sink_bias
self.head_dim = head_dim
self.s_k = s_k
self.n_kv_tiles = s_k // 128
self.pv_n_tile = min(head_dim, 256)
if head_dim > 256:
self.pv_n_tile = 128
self.n_pv_tiles = head_dim // self.pv_n_tile
self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64)
self.num_query_heads = num_query_heads
self.batch_size = batch_size
self.normalize = normalize
self.apply_swa_mask = apply_swa_mask
self.is_causal = is_causal
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim)
self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e)
def _setup(self, qk_mma, pv_mma):
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (128, 128, self.k_tile)
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (128, self.pv_n_tile, pv_ik * (128 // pv_ik))
self.mma_tiler = self.qk_mma_tiler
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), self.pv_n_tile, self.qk_mma_tiler[2])
self.c_layout = LayoutEnum.ROW_MAJOR
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype)
self.num_ab_stage = 1; self.num_acc_stage = 1
self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage)
self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
qk_thr = qk_mma.get_slice(0); qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
pv_thr = pv_mma.get_slice(0); pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
self.tmem_s0_offset = 0
if not self.use_smem_p:
self.tmem_p0_offset = 32
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
p_end = self.tmem_p0_offset + p_cols_fp32
s_cols = self.qk_mma_tiler[1]
o_after = max(s_cols, p_end)
self.tmem_o0_offset = ((o_after + 31) // 32) * 32
o_cols = find_tmem_tensor_col_offset(tOtO)
total = self.tmem_o0_offset + o_cols
else:
self.tmem_p0_offset = -1
self.tmem_o0_offset = 0
s_cols = self.qk_mma_tiler[1]
o_cols = find_tmem_tensor_col_offset(tOtO)
total = max(s_cols, o_cols)
self.num_tmem_alloc_cols = 1
while self.num_tmem_alloc_cols < total:
self.num_tmem_alloc_cols *= 2
self.tOrP0_offset = max(self.tmem_p0_offset, 0) * 2
cta = cute.size(qk_mma.thr_id.shape)
q_s = cute.slice_(self.q_smem_s,(None,None,None,0))
k_s = cute.slice_(self.k_smem_s,(None,None,None,0))
v_s = cute.slice_(self.v_smem_s,(None,None,None,0))
self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta
self.kv_tx_bytes = (cute.size_in_bytes(self.q_dtype, k_s) +
cute.size_in_bytes(self.q_dtype, v_s)) * cta
@cute.jit
def __call__(self, q, k, v, c, stream, lse=None, swa_len=None, sink_bias=None, row_sums=None):
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
v_fmha = cute.make_tensor(
v.iterator,
cute.make_layout(
(self.pv_n_tile, self.s_k, 1),
stride=(1, self.pv_n_tile, self.pv_n_tile * self.s_k),
),
)
self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, tcgen05.CtaGroup.ONE, (128,128), tcgen05.OperandSource.SMEM)
pv_a_major = self.a_major if self.use_smem_p else cute.nvgpu.OperandMajorMode.K
pv_source = tcgen05.OperandSource.SMEM if self.use_smem_p else tcgen05.OperandSource.TMEM
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, pv_a_major, self.v_major, self.qk_acc_dtype, tcgen05.CtaGroup.ONE, (128,self.pv_n_tile), pv_source)
self._setup(qk_mma, pv_mma)
q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0))
tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape)
tma_k,mK = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,qk_mma.thr_id),k,k_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape)
tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape)
epi_s = cute.select(self.c_smem_s,mode=[0,1])
tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile)
if const_expr(lse is None):
lse = cute.make_tensor(c.iterator, cute.make_layout((1,), stride=(0,)))
if const_expr(swa_len is None):
swa_len = Int32(2147483647)
else:
swa_len = Int32(swa_len)
if const_expr(sink_bias is None):
sink_bias = cute.make_tensor(lse.iterator, cute.make_layout((1,), stride=(0,)))
if const_expr(row_sums is None):
row_sums = cute.make_tensor(lse.iterator, lse.layout)
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse,swa_len,sink_bias,row_sums).launch(grid=(1,1,self.batch_size),block=[self.threads_per_cta,1,1],stream=stream)
# ... rest of kernel to be implemented