From 0d3caced47e08fe4a10dd1a58dffe006699b549e Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 21:35:39 +0000 Subject: [PATCH] Add: O rescale (correction_rescale) in softmax loop + remove pk from TMA/MMA --- tests/unit/test_fmha_v3_stage_c.py | 52 +++++++++++++++++++++++++++--- 1 file changed, 48 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index f8549eff..f917ff9c 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -224,7 +224,6 @@ class FmhaV3StageCMulti: kvh = kvp.acquire_and_advance(pk) cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - pk = cutlass.Boolean(1) kvp.tail() # ===== MMA warp ===== @@ -233,11 +232,11 @@ class FmhaV3StageCMulti: if warp_idx == self.mma_warp_id: tmem.wait_for_alloc() qc.reset(); qh = qc.wait_and_advance(); qh.release() - kvc.reset(); pk = kvc.try_wait() - acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage) + kvc.reset() + for kt in range(n_tiles): + kvh = kvc.wait_and_advance() acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage) acc_pipe.producer_acquire(acc_st) for kt in range(n_kv_tiles): - kvh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1) sh = s_prod.acquire_and_advance() qk_mma.set(tcgen05.Field.ACCUMULATE, False) for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True): @@ -290,6 +289,29 @@ class FmhaV3StageCMulti: tScP = cute.make_tensor(tScS.iterator, tScP_layout) tTMEM_STOREcP = thr_store.partition_S(tScP) + # O rescale setup: same correction_rescale pattern as final normalize. + # Uses paired Ld32x32bOp/St32x32bOp atoms with matching Repetition(16). + corr_tile_size = 16 + cO_corr = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) + tOcO_corr = pv_thr.partition_C(cO_corr) + tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size))) + tOcO_i_layout = cute.composition(tOcO_corr.layout, cute.make_layout((128, corr_tile_size))) + tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout) + tOcO_i = cute.make_tensor(tOcO_corr.iterator, tOcO_i_layout) + tmem_load_o_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype) + tmem_store_o_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype) + tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i) + tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i) + thr_tmem_load_o = tiled_tmem_load_o.get_slice(sfw_idx) + thr_tmem_store_o = tiled_tmem_store_o.get_slice(sfw_idx) + tTMEM_LOAD_OtO = thr_tmem_load_o.partition_S(tOtO_i) + tTMEM_LOAD_OcO = thr_tmem_load_o.partition_D(tOcO_i) + tTMEM_STORE_OtO = thr_tmem_store_o.partition_D(tOtO_i) + tTMrO = cute.make_rmem_tensor( + (tTMEM_LOAD_OcO.shape, 128 // corr_tile_size), self.acc_dtype) + row_max = -Float32.inf row_sum = Float32(0.0) scale_log2 = Float32(self.scale_softmax_log2) @@ -334,6 +356,28 @@ class FmhaV3StageCMulti: acc_scale = Float32(0.0) row_sum *= acc_scale + # O rescale: multiply existing O by acc_scale = exp2(old_max - new_max) + # Uses the correction_rescale pattern (same paired atoms as final normalize). + # Must happen BEFORE softmax_done_bar.arrive() so MMA's PV[kt] sees rescaled O. + if kt > 0: + for ci in range(HEAD_DIM // corr_tile_size): + tTMrO_i_ = tTMrO[None, ci] + tTMrO_i_layout = cute.composition( + tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) + ) + tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) + tTMEM_LOAD_OtO_i = cute.make_tensor( + tTMEM_LOAD_OtO.iterator + ci * corr_tile_size, tTMEM_LOAD_OtO.layout + ) + tTMEM_STORE_OtO_i = cute.make_tensor( + tTMEM_STORE_OtO.iterator + ci * corr_tile_size, tTMEM_STORE_OtO.layout + ) + cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_i, tTMrO_i) + for j in cutlass.range(cute.size(tTMrO_i), vectorize=True): + tTMrO_i[j] = tTMrO_i[j] * acc_scale + cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STORE_OtO_i) + cute.arch.fence_view_async_tmem_store() + # Pass 2: P = exp2((S - new_max) * log2), accumulate row_sum, # store BF16 P through the FP32-backed register bridge. rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)