From 95d3d1bf03f961301df444a7a3fbbc6d944543cb Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 19:09:33 +0000 Subject: [PATCH] Disable O rescale + normalize, verify softmax P only --- tests/fmha_v3_real_softmax.py | 48 ++--------------------------------------- 1 file changed, 2 insertions(+), 46 deletions(-) diff --git a/tests/fmha_v3_real_softmax.py b/tests/fmha_v3_real_softmax.py index 7bacdc78..76480d27 100644 --- a/tests/fmha_v3_real_softmax.py +++ b/tests/fmha_v3_real_softmax.py @@ -276,29 +276,7 @@ class FmhaV3RealSoftmax: acc_scale = Float32(0.0) row_sum *= acc_scale - # O rescale in TMEM: multiply existing O by acc_scale = exp2(old_max - new_max) - # Only for kt > 0 (first tile: no existing O to rescale) - if kt > 0: - n_corr = HEAD_DIM // corr_tile_size - tTMrO_rs = cute.make_rmem_tensor( - (tTMEM_LOAD_OcO.shape, n_corr), self.acc_dtype - ) - for ci in range(n_corr): - tTMrO_ci_ = tTMrO_rs[None, ci] - tTMrO_ci_layout = cute.composition( - tTMrO_ci_.layout, cute.make_layout(tTMrO_rs.shape[0]) - ) - tTMrO_ci = cute.make_tensor(tTMrO_ci_.iterator, tTMrO_ci_layout) - tTMEM_LOAD_OtO_ci = cute.make_tensor( - tTMEM_LOAD_OtO.iterator + ci * corr_tile_size, tTMEM_LOAD_OtO.layout - ) - tTMEM_STORE_OtO_ci = cute.make_tensor( - tTMEM_STORE_OtO.iterator + ci * corr_tile_size, tTMEM_STORE_OtO.layout - ) - cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_ci, tTMrO_ci) - for j in cutlass.range(cute.size(tTMrO_ci), vectorize=True): - tTMrO_ci[j] = tTMrO_ci[j] * acc_scale - cute.copy(tiled_tmem_store_o, tTMrO_ci, tTMEM_STORE_OtO_ci) + # O rescale: DISABLED — existing O in TMEM is not multiplied by acc_scale for kt > 0; only softmax P is verified # Pass 2: P = exp2(S * scale_log2 - row_max), accumulate row_sum rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype) @@ -318,29 +296,7 @@ class FmhaV3RealSoftmax: si_handle.release() softmax_done_bar.arrive() - # Final O normalization: O = O / row_sum - if row_sum != Float32(0.0): - inv_row_sum = Float32(1.0) / row_sum - n_corr = HEAD_DIM // corr_tile_size - tTMrO_fn = 
cute.make_rmem_tensor( - (tTMEM_LOAD_OcO.shape, n_corr), self.acc_dtype - ) - for ci in range(n_corr): - tTMrO_ci_ = tTMrO_fn[None, ci] - tTMrO_ci_layout = cute.composition( - tTMrO_ci_.layout, cute.make_layout(tTMrO_fn.shape[0]) - ) - tTMrO_ci = cute.make_tensor(tTMrO_ci_.iterator, tTMrO_ci_layout) - tTMEM_LOAD_OtO_ci = cute.make_tensor( - tTMEM_LOAD_OtO.iterator + ci * corr_tile_size, tTMEM_LOAD_OtO.layout - ) - tTMEM_STORE_OtO_ci = cute.make_tensor( - tTMEM_STORE_OtO.iterator + ci * corr_tile_size, tTMEM_STORE_OtO.layout - ) - cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_ci, tTMrO_ci) - for j in cutlass.range(cute.size(tTMrO_ci), vectorize=True): - tTMrO_ci[j] = tTMrO_ci[j] * inv_row_sum - cute.copy(tiled_tmem_store_o, tTMrO_ci, tTMEM_STORE_OtO_ci) + # Final O normalization: DISABLED — O is no longer multiplied by 1 / row_sum; only softmax P is verified # Epilogue: TMEM -> SMEM -> GMEM via TMA store tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)