From 572656e79b09e0f1e05ab988fc70a4cfd088af3a Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 10:23:37 +0000 Subject: [PATCH] debug: disable O rescaling to test multi-tile pipeline baseline --- tests/unit/test_fmha_v3_stage_c_full.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/tests/unit/test_fmha_v3_stage_c_full.py b/tests/unit/test_fmha_v3_stage_c_full.py index 1c5b1fc5..13e8798a 100644 --- a/tests/unit/test_fmha_v3_stage_c_full.py +++ b/tests/unit/test_fmha_v3_stage_c_full.py @@ -302,18 +302,19 @@ class FmhaV3StageC: cute.arch.fence_view_async_tmem_store() # O rescale: if kt > 0, rescale O in TMEM by exp2((old_max - new_max) * scale_log2) - if kt > 0: - corr_scale = cute.math.exp2(old_row_max - row_max_safe, fastmath=True) - o_col_tiles = self.pv_mma_tiler[1] // corr_tile_size - for i in range(o_col_tiles): - tTMEM_LOAD_O_i = cute.make_tensor(tTMEM_LOAD_OtO.iterator + i * corr_tile_size, tTMEM_LOAD_OtO.layout) - tTMEM_STORE_O_i = cute.make_tensor(tTMEM_STORE_OtO.iterator + i * corr_tile_size, tTMEM_STORE_OtO.layout) - tTMrO = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.acc_dtype) - cute.copy(tiled_tmem_load_o, tTMEM_LOAD_O_i, tTMrO) - for k in cutlass.range(cute.size(tTMrO), vectorize=True): - tTMrO[k] = tTMrO[k] * corr_scale - cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STORE_O_i) - cute.arch.fence_view_async_tmem_store() + # TEMPORARILY DISABLED FOR DEBUGGING + # if kt > 0: + # corr_scale = cute.math.exp2(old_row_max - row_max_safe, fastmath=True) + # o_col_tiles = self.pv_mma_tiler[1] // corr_tile_size + # for i in range(o_col_tiles): + # tTMEM_LOAD_O_i = cute.make_tensor(tTMEM_LOAD_OtO.iterator + i * corr_tile_size, tTMEM_LOAD_OtO.layout) + # tTMEM_STORE_O_i = cute.make_tensor(tTMEM_STORE_OtO.iterator + i * corr_tile_size, tTMEM_STORE_OtO.layout) + # tTMrO = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.acc_dtype) + # cute.copy(tiled_tmem_load_o, tTMEM_LOAD_O_i, tTMrO) + # for k in cutlass.range(cute.size(tTMrO), vectorize=True): + # tTMrO[k] = tTMrO[k] * corr_scale + # cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STORE_O_i) + # cute.arch.fence_view_async_tmem_store() si_handle.release() softmax_done_bar.arrive()