Disable O rescale + normalize, verify softmax P only
This commit is contained in:
@@ -276,29 +276,7 @@ class FmhaV3RealSoftmax:
|
||||
acc_scale = Float32(0.0)
|
||||
row_sum *= acc_scale
|
||||
|
||||
# O rescale in TMEM: multiply existing O by acc_scale = exp2(old_max - new_max)
|
||||
# Only for kt > 0 (first tile: no existing O to rescale)
|
||||
if kt > 0:
|
||||
n_corr = HEAD_DIM // corr_tile_size
|
||||
tTMrO_rs = cute.make_rmem_tensor(
|
||||
(tTMEM_LOAD_OcO.shape, n_corr), self.acc_dtype
|
||||
)
|
||||
for ci in range(n_corr):
|
||||
tTMrO_ci_ = tTMrO_rs[None, ci]
|
||||
tTMrO_ci_layout = cute.composition(
|
||||
tTMrO_ci_.layout, cute.make_layout(tTMrO_rs.shape[0])
|
||||
)
|
||||
tTMrO_ci = cute.make_tensor(tTMrO_ci_.iterator, tTMrO_ci_layout)
|
||||
tTMEM_LOAD_OtO_ci = cute.make_tensor(
|
||||
tTMEM_LOAD_OtO.iterator + ci * corr_tile_size, tTMEM_LOAD_OtO.layout
|
||||
)
|
||||
tTMEM_STORE_OtO_ci = cute.make_tensor(
|
||||
tTMEM_STORE_OtO.iterator + ci * corr_tile_size, tTMEM_STORE_OtO.layout
|
||||
)
|
||||
cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_ci, tTMrO_ci)
|
||||
for j in cutlass.range(cute.size(tTMrO_ci), vectorize=True):
|
||||
tTMrO_ci[j] = tTMrO_ci[j] * acc_scale
|
||||
cute.copy(tiled_tmem_store_o, tTMrO_ci, tTMEM_STORE_OtO_ci)
|
||||
# O rescale: DISABLED — skip for now
|
||||
|
||||
# Pass 2: P = exp2(S * scale_log2 - row_max), accumulate row_sum
|
||||
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)
|
||||
@@ -318,29 +296,7 @@ class FmhaV3RealSoftmax:
|
||||
si_handle.release()
|
||||
softmax_done_bar.arrive()
|
||||
|
||||
# Final O normalization: O = O / row_sum
|
||||
if row_sum != Float32(0.0):
|
||||
inv_row_sum = Float32(1.0) / row_sum
|
||||
n_corr = HEAD_DIM // corr_tile_size
|
||||
tTMrO_fn = cute.make_rmem_tensor(
|
||||
(tTMEM_LOAD_OcO.shape, n_corr), self.acc_dtype
|
||||
)
|
||||
for ci in range(n_corr):
|
||||
tTMrO_ci_ = tTMrO_fn[None, ci]
|
||||
tTMrO_ci_layout = cute.composition(
|
||||
tTMrO_ci_.layout, cute.make_layout(tTMrO_fn.shape[0])
|
||||
)
|
||||
tTMrO_ci = cute.make_tensor(tTMrO_ci_.iterator, tTMrO_ci_layout)
|
||||
tTMEM_LOAD_OtO_ci = cute.make_tensor(
|
||||
tTMEM_LOAD_OtO.iterator + ci * corr_tile_size, tTMEM_LOAD_OtO.layout
|
||||
)
|
||||
tTMEM_STORE_OtO_ci = cute.make_tensor(
|
||||
tTMEM_STORE_OtO.iterator + ci * corr_tile_size, tTMEM_STORE_OtO.layout
|
||||
)
|
||||
cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_ci, tTMrO_ci)
|
||||
for j in cutlass.range(cute.size(tTMrO_ci), vectorize=True):
|
||||
tTMrO_ci[j] = tTMrO_ci[j] * inv_row_sum
|
||||
cute.copy(tiled_tmem_store_o, tTMrO_ci, tTMEM_STORE_OtO_ci)
|
||||
# Final O normalization: DISABLED
|
||||
|
||||
# Epilogue: TMEM -> SMEM -> GMEM via TMA store
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
|
||||
Reference in New Issue
Block a user