diff --git a/tests/fmha_v3_real_softmax.py b/tests/fmha_v3_real_softmax.py index 76480d27..230b5281 100644 --- a/tests/fmha_v3_real_softmax.py +++ b/tests/fmha_v3_real_softmax.py @@ -229,7 +229,7 @@ class FmhaV3RealSoftmax: # O normalize setup: sub-tile O for TMEM read-modify-write cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) tOcO = pv_thr.partition_C(cO) - corr_tile_size = 16 + corr_tile_size = 32 tOtO_i_layout = cute.composition(tOtO.layout, cute.make_layout((128, corr_tile_size))) tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size))) tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout) @@ -276,8 +276,22 @@ class FmhaV3RealSoftmax: acc_scale = Float32(0.0) row_sum *= acc_scale - # O rescale: DISABLED — skip for now - + # O rescale in TMEM: multiply existing O by acc_scale = exp2(old_max - new_max) + # Only for kt > 0 (first tile: no existing O to rescale) + if kt > 0: + n_corr = HEAD_DIM // corr_tile_size + for ci in range(n_corr): + tTMrO_rs = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.acc_dtype) + tTMEM_LOAD_OtO_ci = cute.make_tensor( + tTMEM_LOAD_OtO.iterator + ci * corr_tile_size, tTMEM_LOAD_OtO.layout + ) + tTMEM_STORE_OtO_ci = cute.make_tensor( + tTMEM_STORE_OtO.iterator + ci * corr_tile_size, tTMEM_STORE_OtO.layout + ) + cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_ci, tTMrO_rs) + for j in cutlass.range(cute.size(tTMrO_rs), vectorize=True): + tTMrO_rs[j] = tTMrO_rs[j] * acc_scale + cute.copy(tiled_tmem_store_o, tTMrO_rs, tTMEM_STORE_OtO_ci) # Pass 2: P = exp2(S * scale_log2 - row_max), accumulate row_sum rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype) rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout) @@ -296,7 +310,22 @@ class FmhaV3RealSoftmax: si_handle.release() softmax_done_bar.arrive() - # Final O normalization: DISABLED + # Final O normalization: O = O / row_sum + if row_sum != Float32(0.0): + inv_row_sum = Float32(1.0) / row_sum + n_corr = HEAD_DIM // corr_tile_size + for ci in range(n_corr): + tTMrO_fn = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.acc_dtype) + tTMEM_LOAD_OtO_ci = cute.make_tensor( + tTMEM_LOAD_OtO.iterator + ci * corr_tile_size, tTMEM_LOAD_OtO.layout + ) + tTMEM_STORE_OtO_ci = cute.make_tensor( + tTMEM_STORE_OtO.iterator + ci * corr_tile_size, tTMEM_STORE_OtO.layout + ) + cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_ci, tTMrO_fn) + for j in cutlass.range(cute.size(tTMrO_fn), vectorize=True): + tTMrO_fn[j] = tTMrO_fn[j] * inv_row_sum + cute.copy(tiled_tmem_store_o, tTMrO_fn, tTMEM_STORE_OtO_ci) # Epilogue: TMEM -> SMEM -> GMEM via TMA store tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)