Debug: add row_sum/inv_row_sum printf at final normalize
This commit is contained in:
@@ -332,6 +332,11 @@ class FmhaV3StageCMulti:
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
|
||||
# Pass 1: update row_max (in log2-domain, fused with scale).
|
||||
# Compute O rescale factor and update row_sum.
|
||||
# At kt=0, old_row_max is -inf, so acc_scale = 0 — but
|
||||
# row_sum starts at 0 too, so row_sum *= 0 is fine (0*0=0).
|
||||
# The O rescale (O *= acc_scale) must be skipped at kt=0
|
||||
# because it would zero out the first tile's contribution.
|
||||
old_row_max = row_max
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
|
||||
@@ -398,6 +403,7 @@ class FmhaV3StageCMulti:
|
||||
|
||||
# === Final O normalization: O *= 1/row_sum ===
|
||||
inv_row_sum = Float32(1.0) / row_sum
|
||||
cute.printf("FINAL row_sum=%f inv_row_sum=%f\n", row_sum, inv_row_sum)
|
||||
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
|
||||
|
||||
Reference in New Issue
Block a user