FIX: O sub-tile count should be HEAD_DIM/corr_tile_size, not 128/corr_tile_size
This commit is contained in:
@@ -279,7 +279,7 @@ class FmhaV3RealSoftmax:
|
||||
# O rescale in TMEM: multiply existing O by acc_scale = exp2(old_max - new_max)
|
||||
# Only for kt > 0 (first tile: no existing O to rescale)
|
||||
if kt > 0:
|
||||
n_corr = 128 // corr_tile_size
|
||||
n_corr = HEAD_DIM // corr_tile_size
|
||||
tTMrO_rs = cute.make_rmem_tensor(
|
||||
(tTMEM_LOAD_OcO.shape, n_corr), self.acc_dtype
|
||||
)
|
||||
@@ -321,7 +321,7 @@ class FmhaV3RealSoftmax:
|
||||
# Final O normalization: O = O / row_sum
|
||||
if row_sum != Float32(0.0):
|
||||
inv_row_sum = Float32(1.0) / row_sum
|
||||
n_corr = 128 // corr_tile_size
|
||||
n_corr = HEAD_DIM // corr_tile_size
|
||||
tTMrO_fn = cute.make_rmem_tensor(
|
||||
(tTMEM_LOAD_OcO.shape, n_corr), self.acc_dtype
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user