debug: disable O rescaling to test multi-tile pipeline baseline

This commit is contained in:
2026-05-22 10:23:37 +00:00
parent 8ce257150e
commit 572656e79b

View File

@@ -302,18 +302,19 @@ class FmhaV3StageC:
cute.arch.fence_view_async_tmem_store()
# O rescale: if kt > 0, rescale O in TMEM by exp2((old_max - new_max) * scale_log2)
if kt > 0:
corr_scale = cute.math.exp2(old_row_max - row_max_safe, fastmath=True)
o_col_tiles = self.pv_mma_tiler[1] // corr_tile_size
for i in range(o_col_tiles):
tTMEM_LOAD_O_i = cute.make_tensor(tTMEM_LOAD_OtO.iterator + i * corr_tile_size, tTMEM_LOAD_OtO.layout)
tTMEM_STORE_O_i = cute.make_tensor(tTMEM_STORE_OtO.iterator + i * corr_tile_size, tTMEM_STORE_OtO.layout)
tTMrO = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.acc_dtype)
cute.copy(tiled_tmem_load_o, tTMEM_LOAD_O_i, tTMrO)
for k in cutlass.range(cute.size(tTMrO), vectorize=True):
tTMrO[k] = tTMrO[k] * corr_scale
cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STORE_O_i)
cute.arch.fence_view_async_tmem_store()
# TEMPORARILY DISABLED FOR DEBUGGING
# if kt > 0:
# corr_scale = cute.math.exp2(old_row_max - row_max_safe, fastmath=True)
# o_col_tiles = self.pv_mma_tiler[1] // corr_tile_size
# for i in range(o_col_tiles):
# tTMEM_LOAD_O_i = cute.make_tensor(tTMEM_LOAD_OtO.iterator + i * corr_tile_size, tTMEM_LOAD_OtO.layout)
# tTMEM_STORE_O_i = cute.make_tensor(tTMEM_STORE_OtO.iterator + i * corr_tile_size, tTMEM_STORE_OtO.layout)
# tTMrO = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.acc_dtype)
# cute.copy(tiled_tmem_load_o, tTMEM_LOAD_O_i, tTMrO)
# for k in cutlass.range(cute.size(tTMrO), vectorize=True):
# tTMrO[k] = tTMrO[k] * corr_scale
# cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STORE_O_i)
# cute.arch.fence_view_async_tmem_store()
si_handle.release()
softmax_done_bar.arrive()