DEBUG: disable O rescale + normalize, test if corr setup alone causes regression

This commit is contained in:
2026-05-22 19:57:53 +00:00
parent 4fe88a6994
commit ad6ea439a4

View File

@@ -289,7 +289,7 @@ class FmhaV3StageCMulti:
# Per-tile softmax loop.
# Online softmax row_max/row_sum tracking is maintained, but the
# in-place TMEM O rescale (which would multiply existing O by
# O rescale + final normalize setup: single set of correction_rescale tensors
# O corr setup: DISABLED to debug n=128 regression
corr_tile_size = 16
cO_corr = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1]))
tOcO_corr = pv_thr.partition_C(cO_corr)
@@ -311,6 +311,16 @@ class FmhaV3StageCMulti:
tTMrO = cute.make_rmem_tensor(
(tTMEM_LOAD_OcO.shape, 128 // corr_tile_size), self.acc_dtype)
# DISABLED: O rescale (kt > 0)
# if kt > 0:
# for ci in range(HEAD_DIM // corr_tile_size):
# ...
# DISABLED: Final O normalize (1/row_sum)
# inv_row_sum = Float32(1.0) / row_sum
# for i in range(HEAD_DIM // corr_tile_size):
# ...
row_max = -Float32.inf
row_sum = Float32(0.0)
scale_log2 = Float32(self.scale_softmax_log2)
@@ -344,25 +354,7 @@ class FmhaV3StageCMulti:
acc_scale = Float32(0.0)
row_sum *= acc_scale
# O rescale: multiply existing O by acc_scale = exp2(old_max - new_max)
if kt > 0:
for ci in range(HEAD_DIM // corr_tile_size):
tTMrO_i_ = tTMrO[None, ci]
tTMrO_i_layout = cute.composition(
tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0])
)
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
tTMEM_LOAD_OtO_i = cute.make_tensor(
tTMEM_LOAD_OtO.iterator + ci * corr_tile_size, tTMEM_LOAD_OtO.layout
)
tTMEM_STORE_OtO_i = cute.make_tensor(
tTMEM_STORE_OtO.iterator + ci * corr_tile_size, tTMEM_STORE_OtO.layout
)
cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_i, tTMrO_i)
for j in cutlass.range(cute.size(tTMrO_i), vectorize=True):
tTMrO_i[j] = tTMrO_i[j] * acc_scale
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STORE_OtO_i)
cute.arch.fence_view_async_tmem_store()
# O rescale: DISABLED (debugging n=128 regression)
# Pass 2: P = exp2((S - new_max) * log2), accumulate row_sum,
# store BF16 P through the FP32-backed register bridge.
@@ -404,29 +396,7 @@ class FmhaV3StageCMulti:
# Wait for MMA's PV[N-1] to commit before reading O.
final_o_bar.arrive_and_wait()
# === Final O normalization: 1/row_sum ===
# Reuses the same corr_tile_size + tiled_tmem_load_o/store_o from above.
inv_row_sum = Float32(1.0) / row_sum
for i in range(HEAD_DIM // corr_tile_size):
tTMrO_i_ = tTMrO[None, i]
tTMrO_i_layout = cute.composition(
tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0])
)
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
tTMEM_LOAD_OtO_i = cute.make_tensor(
tTMEM_LOAD_OtO.iterator + i * corr_tile_size, tTMEM_LOAD_OtO.layout
)
tTMEM_STORE_OtO_i = cute.make_tensor(
tTMEM_STORE_OtO.iterator + i * corr_tile_size, tTMEM_STORE_OtO.layout
)
cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_i, tTMrO_i)
for j in cutlass.range(cute.size(tTMrO_i), vectorize=True):
tTMrO_i[j] = tTMrO_i[j] * inv_row_sum
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STORE_OtO_i)
cute.arch.fence_view_async_tmem_store()
# === Final O normalization: DISABLED (debugging) ===
# Standard epilogue: TMEM → SMEM → GMEM via TMA store.
# O in TMEM is now scaled by 1/row_sum.