diff --git a/tests/unit/test_fmha_v3_stage_c_full.py b/tests/unit/test_fmha_v3_stage_c_full.py index 1c5b1fc5..13e8798a 100644 --- a/tests/unit/test_fmha_v3_stage_c_full.py +++ b/tests/unit/test_fmha_v3_stage_c_full.py @@ -302,18 +302,19 @@ class FmhaV3StageC: cute.arch.fence_view_async_tmem_store() # O rescale: if kt > 0, rescale O in TMEM by exp2((old_max - new_max) * scale_log2) - if kt > 0: - corr_scale = cute.math.exp2(old_row_max - row_max_safe, fastmath=True) - o_col_tiles = self.pv_mma_tiler[1] // corr_tile_size - for i in range(o_col_tiles): - tTMEM_LOAD_O_i = cute.make_tensor(tTMEM_LOAD_OtO.iterator + i * corr_tile_size, tTMEM_LOAD_OtO.layout) - tTMEM_STORE_O_i = cute.make_tensor(tTMEM_STORE_OtO.iterator + i * corr_tile_size, tTMEM_STORE_OtO.layout) - tTMrO = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.acc_dtype) - cute.copy(tiled_tmem_load_o, tTMEM_LOAD_O_i, tTMrO) - for k in cutlass.range(cute.size(tTMrO), vectorize=True): - tTMrO[k] = tTMrO[k] * corr_scale - cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STORE_O_i) - cute.arch.fence_view_async_tmem_store() + # TEMPORARILY DISABLED FOR DEBUGGING + # if kt > 0: + # corr_scale = cute.math.exp2(old_row_max - row_max_safe, fastmath=True) + # o_col_tiles = self.pv_mma_tiler[1] // corr_tile_size + # for i in range(o_col_tiles): + # tTMEM_LOAD_O_i = cute.make_tensor(tTMEM_LOAD_OtO.iterator + i * corr_tile_size, tTMEM_LOAD_OtO.layout) + # tTMEM_STORE_O_i = cute.make_tensor(tTMEM_STORE_OtO.iterator + i * corr_tile_size, tTMEM_STORE_OtO.layout) + # tTMrO = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.acc_dtype) + # cute.copy(tiled_tmem_load_o, tTMEM_LOAD_O_i, tTMrO) + # for k in cutlass.range(cute.size(tTMrO), vectorize=True): + # tTMrO[k] = tTMrO[k] * corr_scale + # cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STORE_O_i) + # cute.arch.fence_view_async_tmem_store() si_handle.release() softmax_done_bar.arrive()