DEBUG: disable O rescale to isolate NaN cause

This commit is contained in:
2026-05-22 18:05:46 +00:00
parent 1c3970fe58
commit 02d993ecac

View File

@@ -354,27 +354,12 @@ class FmhaV3StageC:
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
# O rescale for kt > 0. Reads O written by MMA's PV[kt-1];
# s_cons.wait_and_advance acquires on MMA's S[kt] commit (sequenced
# after PV[kt-1]), but add an explicit tmem_load fence for safety.
if kt > 0:
cute.arch.fence_view_async_tmem_load()
for i in range(o_col_tiles):
tTMEM_LOAD_O_i = cute.make_tensor(
tTMEM_LOAD_OtO.iterator + i * corr_tile_size,
tTMEM_LOAD_OtO.layout,
)
tTMEM_STORE_O_i = cute.make_tensor(
tTMEM_STORE_OtO.iterator + i * corr_tile_size,
tTMEM_STORE_OtO.layout,
)
tTMrO = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.acc_dtype)
cute.copy(tiled_tmem_load_o, tTMEM_LOAD_O_i, tTMrO)
cute.arch.fence_view_async_tmem_load()
for k in cutlass.range(cute.size(tTMrO), vectorize=True):
tTMrO[k] = tTMrO[k] * acc_scale
cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STORE_O_i)
cute.arch.fence_view_async_tmem_store()
# O rescale TEMPORARILY DISABLED for debugging NaN
# if kt > 0:
# cute.arch.fence_view_async_tmem_load()
# for i in range(o_col_tiles):
# ...
# cute.arch.fence_view_async_tmem_store()
si_handle.release()
softmax_done_bar.arrive()