diff --git a/tests/fmha_v3_stage_c_example7.py b/tests/fmha_v3_stage_c_example7.py index 22a0d4b9..7f5bb9c5 100644 --- a/tests/fmha_v3_stage_c_example7.py +++ b/tests/fmha_v3_stage_c_example7.py @@ -289,7 +289,7 @@ class FmhaV3StageCMulti: # Per-tile softmax loop. # Online softmax row_max/row_sum tracking is maintained, but the # in-place TMEM O rescale (which would multiply existing O by - # O rescale + final normalize setup: single set of correction_rescale tensors + # O corr setup: DISABLED to debug n=128 regression corr_tile_size = 16 cO_corr = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1])) tOcO_corr = pv_thr.partition_C(cO_corr) @@ -311,6 +311,16 @@ class FmhaV3StageCMulti: tTMrO = cute.make_rmem_tensor( (tTMEM_LOAD_OcO.shape, 128 // corr_tile_size), self.acc_dtype) + # DISABLED: O rescale (kt > 0) + # if kt > 0: + # for ci in range(HEAD_DIM // corr_tile_size): + # ... + + # DISABLED: Final O normalize (1/row_sum) + # inv_row_sum = Float32(1.0) / row_sum + # for i in range(HEAD_DIM // corr_tile_size): + # ... + row_max = -Float32.inf row_sum = Float32(0.0) scale_log2 = Float32(self.scale_softmax_log2) @@ -344,25 +354,7 @@ class FmhaV3StageCMulti: acc_scale = Float32(0.0) row_sum *= acc_scale - # O rescale: multiply existing O by acc_scale = exp2(old_max - new_max) - if kt > 0: - for ci in range(HEAD_DIM // corr_tile_size): - tTMrO_i_ = tTMrO[None, ci] - tTMrO_i_layout = cute.composition( - tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) - ) - tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) - tTMEM_LOAD_OtO_i = cute.make_tensor( - tTMEM_LOAD_OtO.iterator + ci * corr_tile_size, tTMEM_LOAD_OtO.layout - ) - tTMEM_STORE_OtO_i = cute.make_tensor( - tTMEM_STORE_OtO.iterator + ci * corr_tile_size, tTMEM_STORE_OtO.layout - ) - cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_i, tTMrO_i) - for j in cutlass.range(cute.size(tTMrO_i), vectorize=True): - tTMrO_i[j] = tTMrO_i[j] * acc_scale - cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STORE_OtO_i) - cute.arch.fence_view_async_tmem_store() + # O rescale: DISABLED (debugging n=128 regression) # Pass 2: P = exp2((S - new_max) * log2), accumulate row_sum, # store BF16 P through the FP32-backed register bridge. @@ -404,29 +396,7 @@ class FmhaV3StageCMulti: # Wait for MMA's PV[N-1] to commit before reading O. final_o_bar.arrive_and_wait() - # === Final O normalization: 1/row_sum === - # Reuses the same corr_tile_size + tiled_tmem_load_o/store_o from above. - inv_row_sum = Float32(1.0) / row_sum - - for i in range(HEAD_DIM // corr_tile_size): - tTMrO_i_ = tTMrO[None, i] - tTMrO_i_layout = cute.composition( - tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) - ) - tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) - tTMEM_LOAD_OtO_i = cute.make_tensor( - tTMEM_LOAD_OtO.iterator + i * corr_tile_size, tTMEM_LOAD_OtO.layout - ) - tTMEM_STORE_OtO_i = cute.make_tensor( - tTMEM_STORE_OtO.iterator + i * corr_tile_size, tTMEM_STORE_OtO.layout - ) - - cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_i, tTMrO_i) - for j in cutlass.range(cute.size(tTMrO_i), vectorize=True): - tTMrO_i[j] = tTMrO_i[j] * inv_row_sum - cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STORE_OtO_i) - - cute.arch.fence_view_async_tmem_store() + # === Final O normalization: DISABLED (debugging) === # Standard epilogue: TMEM → SMEM → GMEM via TMA store. # O in TMEM is now scaled by 1/row_sum.