Disable O rescale + normalize, verify softmax P only

This commit is contained in:
2026-05-22 19:09:33 +00:00
parent 69ec2ce7e5
commit 95d3d1bf03

View File

@@ -276,29 +276,7 @@ class FmhaV3RealSoftmax:
acc_scale = Float32(0.0)
row_sum *= acc_scale
# O rescale in TMEM: multiply existing O by acc_scale = exp2(old_max - new_max)
# Only for kt > 0 (first tile: no existing O to rescale)
if kt > 0:
n_corr = HEAD_DIM // corr_tile_size
tTMrO_rs = cute.make_rmem_tensor(
(tTMEM_LOAD_OcO.shape, n_corr), self.acc_dtype
)
for ci in range(n_corr):
tTMrO_ci_ = tTMrO_rs[None, ci]
tTMrO_ci_layout = cute.composition(
tTMrO_ci_.layout, cute.make_layout(tTMrO_rs.shape[0])
)
tTMrO_ci = cute.make_tensor(tTMrO_ci_.iterator, tTMrO_ci_layout)
tTMEM_LOAD_OtO_ci = cute.make_tensor(
tTMEM_LOAD_OtO.iterator + ci * corr_tile_size, tTMEM_LOAD_OtO.layout
)
tTMEM_STORE_OtO_ci = cute.make_tensor(
tTMEM_STORE_OtO.iterator + ci * corr_tile_size, tTMEM_STORE_OtO.layout
)
cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_ci, tTMrO_ci)
for j in cutlass.range(cute.size(tTMrO_ci), vectorize=True):
tTMrO_ci[j] = tTMrO_ci[j] * acc_scale
cute.copy(tiled_tmem_store_o, tTMrO_ci, tTMEM_STORE_OtO_ci)
# O rescale: DISABLED — skip for now
# Pass 2: P = exp2(S * scale_log2 - row_max), accumulate row_sum
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)
@@ -318,29 +296,7 @@ class FmhaV3RealSoftmax:
si_handle.release()
softmax_done_bar.arrive()
# Final O normalization: O = O / row_sum
if row_sum != Float32(0.0):
inv_row_sum = Float32(1.0) / row_sum
n_corr = HEAD_DIM // corr_tile_size
tTMrO_fn = cute.make_rmem_tensor(
(tTMEM_LOAD_OcO.shape, n_corr), self.acc_dtype
)
for ci in range(n_corr):
tTMrO_ci_ = tTMrO_fn[None, ci]
tTMrO_ci_layout = cute.composition(
tTMrO_ci_.layout, cute.make_layout(tTMrO_fn.shape[0])
)
tTMrO_ci = cute.make_tensor(tTMrO_ci_.iterator, tTMrO_ci_layout)
tTMEM_LOAD_OtO_ci = cute.make_tensor(
tTMEM_LOAD_OtO.iterator + ci * corr_tile_size, tTMEM_LOAD_OtO.layout
)
tTMEM_STORE_OtO_ci = cute.make_tensor(
tTMEM_STORE_OtO.iterator + ci * corr_tile_size, tTMEM_STORE_OtO.layout
)
cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_ci, tTMrO_ci)
for j in cutlass.range(cute.size(tTMrO_ci), vectorize=True):
tTMrO_ci[j] = tTMrO_ci[j] * inv_row_sum
cute.copy(tiled_tmem_store_o, tTMrO_ci, tTMEM_STORE_OtO_ci)
# Final O normalization: DISABLED
# Epilogue: TMEM -> SMEM -> GMEM via TMA store
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)