Add O rescale with correction_rescale pattern + fix TMA to working diag pattern
This commit is contained in:
@@ -203,22 +203,19 @@ class FmhaV3StageCMulti:
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ===== TMA LOAD warp =====
|
||||
# GMEM tile coordinate: use the cutlass.range induction variable kt
|
||||
# directly. CuTeDSL's `cutlass.range` doesn't auto-detect a Python `+=`
|
||||
# rebinding as a loop-carried iter_args update — the JIT traces the
|
||||
# body once and captures whatever value `kv_coord` had at trace time,
|
||||
# so an outer `kv_coord = Int32(0)` plus a `kv_coord += 1` inside the
|
||||
# loop bakes 0 into every iteration's TMA descriptor at runtime.
|
||||
# The induction variable IS the loop-carried state, properly tracked.
|
||||
# Use the SAME pattern as the working test_fmha_v3_diag.py:
|
||||
# kv_coord = Int32(0+0) + kv_coord += 1 in cutlass.range
|
||||
if warp_idx == self.tma_warp_id:
|
||||
qp.reset(); qh = qp.acquire_and_advance()
|
||||
cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
|
||||
qp.tail()
|
||||
kvp.reset(); pk = kvp.try_acquire()
|
||||
for kt in cutlass.range(0, n_kv_tiles, 1, unroll=1):
|
||||
kv_coord = Int32(0 + 0)
|
||||
for kt in cutlass.range(self.s_k // 128, unroll=1):
|
||||
kvh = kvp.acquire_and_advance(pk)
|
||||
cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_k, tBgK[(None, kv_coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_v, tVgV[(None, kv_coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
kv_coord += 1
|
||||
pk = cutlass.Boolean(1)
|
||||
kvp.tail()
|
||||
|
||||
@@ -292,16 +289,34 @@ class FmhaV3StageCMulti:
|
||||
# Per-tile softmax loop.
|
||||
# Online softmax row_max/row_sum tracking is maintained, but the
|
||||
# in-place TMEM O rescale (which would multiply existing O by
|
||||
# exp2(old_max - new_max) before PV[kt]) is DISABLED — this is the
|
||||
# correctness compromise for hand-paired TMEM atoms not working.
|
||||
# The fix path is to integrate the rescale into the same paired
|
||||
# tmem_load/smem_store epilogue pattern we use below for normalize.
|
||||
# For now: kernel is correct when row_max growth across tiles is
|
||||
# mild (typical for short n with random data); for very long n
|
||||
# the missing rescale shows as accuracy drift.
|
||||
# O rescale setup: same correction_rescale pattern as the final normalize
|
||||
corr_tile_size_rs = 16
|
||||
cO_rs = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1]))
|
||||
tOcO_rs = pv_thr.partition_C(cO_rs)
|
||||
tOtO_rs_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size_rs)))
|
||||
tOcO_rs_i_layout = cute.composition(tOcO_rs.layout, cute.make_layout((128, corr_tile_size_rs)))
|
||||
tOtO_rs_i = cute.make_tensor(tOtO0.iterator, tOtO_rs_i_layout)
|
||||
tOcO_rs_i = cute.make_tensor(tOcO_rs.iterator, tOcO_rs_i_layout)
|
||||
tmem_load_o_rs_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size_rs)), self.acc_dtype)
|
||||
tmem_store_o_rs_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size_rs)), self.acc_dtype)
|
||||
tiled_tmem_load_o_rs = tcgen05.make_tmem_copy(tmem_load_o_rs_atom, tOtO_rs_i)
|
||||
tiled_tmem_store_o_rs = tcgen05.make_tmem_copy(tmem_store_o_rs_atom, tOtO_rs_i)
|
||||
thr_tmem_load_o_rs = tiled_tmem_load_o_rs.get_slice(sfw_idx)
|
||||
thr_tmem_store_o_rs = tiled_tmem_store_o_rs.get_slice(sfw_idx)
|
||||
tTMEM_LOAD_OtO_rs = thr_tmem_load_o_rs.partition_S(tOtO_rs_i)
|
||||
tTMEM_LOAD_OcO_rs = thr_tmem_load_o_rs.partition_D(tOcO_rs_i)
|
||||
tTMEM_STORE_OtO_rs = thr_tmem_store_o_rs.partition_D(tOtO_rs_i)
|
||||
tTMrO_rs = cute.make_rmem_tensor(
|
||||
(tTMEM_LOAD_OcO_rs.shape, 128 // corr_tile_size_rs), self.acc_dtype)
|
||||
|
||||
row_max = -Float32.inf
|
||||
row_sum = Float32(0.0)
|
||||
scale_log2 = Float32(self.scale_softmax_log2)
|
||||
|
||||
for kt in range(n_kv_tiles):
|
||||
si_handle = s_cons.wait_and_advance()
|
||||
|
||||
# Load S[kt]
|
||||
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
|
||||
@@ -329,6 +344,27 @@ class FmhaV3StageCMulti:
|
||||
acc_scale = Float32(0.0)
|
||||
row_sum *= acc_scale
|
||||
|
||||
# O rescale: multiply existing O by acc_scale = exp2(old_max - new_max)
|
||||
# Uses the same correction_rescale pattern verified for final normalize.
|
||||
if kt > 0:
|
||||
for ci in range(HEAD_DIM // corr_tile_size_rs):
|
||||
tTMrO_rs_i_ = tTMrO_rs[None, ci]
|
||||
tTMrO_rs_i_layout = cute.composition(
|
||||
tTMrO_rs_i_.layout, cute.make_layout(tTMrO_rs.shape[0])
|
||||
)
|
||||
tTMrO_rs_i = cute.make_tensor(tTMrO_rs_i_.iterator, tTMrO_rs_i_layout)
|
||||
tTMEM_LOAD_OtO_rs_i = cute.make_tensor(
|
||||
tTMEM_LOAD_OtO_rs.iterator + ci * corr_tile_size_rs, tTMEM_LOAD_OtO_rs.layout
|
||||
)
|
||||
tTMEM_STORE_OtO_rs_i = cute.make_tensor(
|
||||
tTMEM_STORE_OtO_rs.iterator + ci * corr_tile_size_rs, tTMEM_STORE_OtO_rs.layout
|
||||
)
|
||||
cute.copy(tiled_tmem_load_o_rs, tTMEM_LOAD_OtO_rs_i, tTMrO_rs_i)
|
||||
for j in cutlass.range(cute.size(tTMrO_rs_i), vectorize=True):
|
||||
tTMrO_rs_i[j] = tTMrO_rs_i[j] * acc_scale
|
||||
cute.copy(tiled_tmem_store_o_rs, tTMrO_rs_i, tTMEM_STORE_OtO_rs_i)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# Pass 2: P = exp2((S - new_max) * log2), accumulate row_sum,
|
||||
# store BF16 P through the FP32-backed register bridge.
|
||||
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)
|
||||
|
||||
Reference in New Issue
Block a user