Add O rescale with correction_rescale pattern + fix TMA to working diag pattern

This commit is contained in:
2026-05-22 19:51:53 +00:00
parent 0bdcdc0efd
commit f165257c50

View File

@@ -203,22 +203,19 @@ class FmhaV3StageCMulti:
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ===== TMA LOAD warp =====
# GMEM tile coordinate: use the cutlass.range induction variable kt
# directly. CuTeDSL's `cutlass.range` doesn't auto-detect a Python `+=`
# rebinding as a loop-carried iter_args update — the JIT traces the
# body once and captures whatever value `kv_coord` had at trace time,
# so an outer `kv_coord = Int32(0)` plus a `kv_coord += 1` inside the
# loop bakes 0 into every iteration's TMA descriptor at runtime.
# The induction variable IS the loop-carried state, properly tracked.
# Use the SAME pattern as the working test_fmha_v3_diag.py:
# kv_coord = Int32(0+0) + kv_coord += 1 in cutlass.range
if warp_idx == self.tma_warp_id:
qp.reset(); qh = qp.acquire_and_advance()
cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
qp.tail()
kvp.reset(); pk = kvp.try_acquire()
for kt in cutlass.range(0, n_kv_tiles, 1, unroll=1):
kv_coord = Int32(0 + 0)
for kt in cutlass.range(self.s_k // 128, unroll=1):
kvh = kvp.acquire_and_advance(pk)
cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_k, tBgK[(None, kv_coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, kv_coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
kv_coord += 1
pk = cutlass.Boolean(1)
kvp.tail()
@@ -292,16 +289,34 @@ class FmhaV3StageCMulti:
# Per-tile softmax loop.
# Online softmax row_max/row_sum tracking is maintained, but the
# in-place TMEM O rescale (which would multiply existing O by
# exp2(old_max - new_max) before PV[kt]) is DISABLED — this is the
# correctness compromise for hand-paired TMEM atoms not working.
# The fix path is to integrate the rescale into the same paired
# tmem_load/smem_store epilogue pattern we use below for normalize.
# For now: kernel is correct when row_max growth across tiles is
# mild (typical for short n with random data); for very long n
# the missing rescale shows as accuracy drift.
# O rescale setup: same correction_rescale pattern as the final normalize
corr_tile_size_rs = 16
cO_rs = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1]))
tOcO_rs = pv_thr.partition_C(cO_rs)
tOtO_rs_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size_rs)))
tOcO_rs_i_layout = cute.composition(tOcO_rs.layout, cute.make_layout((128, corr_tile_size_rs)))
tOtO_rs_i = cute.make_tensor(tOtO0.iterator, tOtO_rs_i_layout)
tOcO_rs_i = cute.make_tensor(tOcO_rs.iterator, tOcO_rs_i_layout)
tmem_load_o_rs_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size_rs)), self.acc_dtype)
tmem_store_o_rs_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size_rs)), self.acc_dtype)
tiled_tmem_load_o_rs = tcgen05.make_tmem_copy(tmem_load_o_rs_atom, tOtO_rs_i)
tiled_tmem_store_o_rs = tcgen05.make_tmem_copy(tmem_store_o_rs_atom, tOtO_rs_i)
thr_tmem_load_o_rs = tiled_tmem_load_o_rs.get_slice(sfw_idx)
thr_tmem_store_o_rs = tiled_tmem_store_o_rs.get_slice(sfw_idx)
tTMEM_LOAD_OtO_rs = thr_tmem_load_o_rs.partition_S(tOtO_rs_i)
tTMEM_LOAD_OcO_rs = thr_tmem_load_o_rs.partition_D(tOcO_rs_i)
tTMEM_STORE_OtO_rs = thr_tmem_store_o_rs.partition_D(tOtO_rs_i)
tTMrO_rs = cute.make_rmem_tensor(
(tTMEM_LOAD_OcO_rs.shape, 128 // corr_tile_size_rs), self.acc_dtype)
row_max = -Float32.inf
row_sum = Float32(0.0)
scale_log2 = Float32(self.scale_softmax_log2)
for kt in range(n_kv_tiles):
si_handle = s_cons.wait_and_advance()
# Load S[kt]
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
@@ -329,6 +344,27 @@ class FmhaV3StageCMulti:
acc_scale = Float32(0.0)
row_sum *= acc_scale
# O rescale: multiply existing O by acc_scale = exp2(old_max - new_max)
# Uses the same correction_rescale pattern verified for final normalize.
if kt > 0:
for ci in range(HEAD_DIM // corr_tile_size_rs):
tTMrO_rs_i_ = tTMrO_rs[None, ci]
tTMrO_rs_i_layout = cute.composition(
tTMrO_rs_i_.layout, cute.make_layout(tTMrO_rs.shape[0])
)
tTMrO_rs_i = cute.make_tensor(tTMrO_rs_i_.iterator, tTMrO_rs_i_layout)
tTMEM_LOAD_OtO_rs_i = cute.make_tensor(
tTMEM_LOAD_OtO_rs.iterator + ci * corr_tile_size_rs, tTMEM_LOAD_OtO_rs.layout
)
tTMEM_STORE_OtO_rs_i = cute.make_tensor(
tTMEM_STORE_OtO_rs.iterator + ci * corr_tile_size_rs, tTMEM_STORE_OtO_rs.layout
)
cute.copy(tiled_tmem_load_o_rs, tTMEM_LOAD_OtO_rs_i, tTMrO_rs_i)
for j in cutlass.range(cute.size(tTMrO_rs_i), vectorize=True):
tTMrO_rs_i[j] = tTMrO_rs_i[j] * acc_scale
cute.copy(tiled_tmem_store_o_rs, tTMrO_rs_i, tTMEM_STORE_OtO_rs_i)
cute.arch.fence_view_async_tmem_store()
# Pass 2: P = exp2((S - new_max) * log2), accumulate row_sum,
# store BF16 P through the FP32-backed register bridge.
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)