Add: O rescale (correction_rescale) in softmax loop + remove pk from TMA/MMA

This commit is contained in:
2026-05-22 21:35:39 +00:00
parent c47d229e6a
commit 0d3caced47

View File

@@ -224,7 +224,6 @@ class FmhaV3StageCMulti:
kvh = kvp.acquire_and_advance(pk)
cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
pk = cutlass.Boolean(1)
kvp.tail()
# ===== MMA warp =====
@@ -233,11 +232,11 @@ class FmhaV3StageCMulti:
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
qc.reset(); qh = qc.wait_and_advance(); qh.release()
kvc.reset(); pk = kvc.try_wait()
acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
kvc.reset()
for kt in range(n_tiles):
kvh = kvc.wait_and_advance() acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_st)
for kt in range(n_kv_tiles):
kvh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1)
sh = s_prod.acquire_and_advance()
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True):
@@ -290,6 +289,29 @@ class FmhaV3StageCMulti:
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
tTMEM_STOREcP = thr_store.partition_S(tScP)
# O rescale setup: same correction_rescale pattern as final normalize.
# Uses paired Ld32x32bOp/St32x32bOp atoms with matching Repetition(16).
corr_tile_size = 16
cO_corr = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1]))
tOcO_corr = pv_thr.partition_C(cO_corr)
tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size)))
tOcO_i_layout = cute.composition(tOcO_corr.layout, cute.make_layout((128, corr_tile_size)))
tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout)
tOcO_i = cute.make_tensor(tOcO_corr.iterator, tOcO_i_layout)
tmem_load_o_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype)
tmem_store_o_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype)
tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i)
tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i)
thr_tmem_load_o = tiled_tmem_load_o.get_slice(sfw_idx)
thr_tmem_store_o = tiled_tmem_store_o.get_slice(sfw_idx)
tTMEM_LOAD_OtO = thr_tmem_load_o.partition_S(tOtO_i)
tTMEM_LOAD_OcO = thr_tmem_load_o.partition_D(tOcO_i)
tTMEM_STORE_OtO = thr_tmem_store_o.partition_D(tOtO_i)
tTMrO = cute.make_rmem_tensor(
(tTMEM_LOAD_OcO.shape, 128 // corr_tile_size), self.acc_dtype)
row_max = -Float32.inf
row_sum = Float32(0.0)
scale_log2 = Float32(self.scale_softmax_log2)
@@ -334,6 +356,28 @@ class FmhaV3StageCMulti:
acc_scale = Float32(0.0)
row_sum *= acc_scale
# O rescale: multiply existing O by acc_scale = exp2(old_max - new_max)
# Uses the correction_rescale pattern (same paired atoms as final normalize).
# Must happen BEFORE softmax_done_bar.arrive() so MMA's PV[kt] sees rescaled O.
if kt > 0:
for ci in range(HEAD_DIM // corr_tile_size):
tTMrO_i_ = tTMrO[None, ci]
tTMrO_i_layout = cute.composition(
tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0])
)
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
tTMEM_LOAD_OtO_i = cute.make_tensor(
tTMEM_LOAD_OtO.iterator + ci * corr_tile_size, tTMEM_LOAD_OtO.layout
)
tTMEM_STORE_OtO_i = cute.make_tensor(
tTMEM_STORE_OtO.iterator + ci * corr_tile_size, tTMEM_STORE_OtO.layout
)
cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_i, tTMrO_i)
for j in cutlass.range(cute.size(tTMrO_i), vectorize=True):
tTMrO_i[j] = tTMrO_i[j] * acc_scale
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STORE_OtO_i)
cute.arch.fence_view_async_tmem_store()
# Pass 2: P = exp2((S - new_max) * log2), accumulate row_sum,
# store BF16 P through the FP32-backed register bridge.
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)