FIX: O sub-tile count should be HEAD_DIM/corr_tile_size, not 128/corr_tile_size

This commit is contained in:
2026-05-22 19:06:52 +00:00
parent 48b24ba005
commit dcc64dd14d

View File

@@ -279,7 +279,7 @@ class FmhaV3RealSoftmax:
# O rescale in TMEM: multiply existing O by acc_scale = exp2(old_max - new_max)
# Only for kt > 0 (first tile: no existing O to rescale)
if kt > 0:
- n_corr = 128 // corr_tile_size
+ n_corr = HEAD_DIM // corr_tile_size
tTMrO_rs = cute.make_rmem_tensor(
(tTMEM_LOAD_OcO.shape, n_corr), self.acc_dtype
)
@@ -321,7 +321,7 @@ class FmhaV3RealSoftmax:
# Final O normalization: O = O / row_sum
if row_sum != Float32(0.0):
inv_row_sum = Float32(1.0) / row_sum
- n_corr = 128 // corr_tile_size
+ n_corr = HEAD_DIM // corr_tile_size
tTMrO_fn = cute.make_rmem_tensor(
(tTMEM_LOAD_OcO.shape, n_corr), self.acc_dtype
)