Debug: add row_sum/inv_row_sum printf at final normalize

This commit is contained in:
2026-05-23 00:34:38 +00:00
parent c2ff8e072e
commit d511ebe387

View File

@@ -332,6 +332,11 @@ class FmhaV3StageCMulti:
cute.arch.fence_view_async_tmem_load()
# Pass 1: update row_max (in log2-domain, fused with scale).
# Compute O rescale factor and update row_sum.
# At kt=0, old_row_max is -inf, so acc_scale = 0. Because row_sum
# also starts at 0, row_sum *= acc_scale is harmless (0 * 0 = 0).
# The O rescale (O *= acc_scale), however, must be skipped at kt=0
# because it would zero out the first tile's contribution.
old_row_max = row_max
frg_cnt = 4
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
@@ -398,6 +403,7 @@ class FmhaV3StageCMulti:
# === Final O normalization: O *= 1/row_sum ===
inv_row_sum = Float32(1.0) / row_sum
cute.printf("FINAL row_sum=%f inv_row_sum=%f\n", row_sum, inv_row_sum)
tTMrO = cute.make_rmem_tensor(
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype