Disable O rescale too for NO-OP test

This commit is contained in:
2026-05-22 19:13:43 +00:00
parent 62529df638
commit d9f7443b77

View File

@@ -276,22 +276,11 @@ class FmhaV3RealSoftmax:
acc_scale = Float32(0.0)
row_sum *= acc_scale
# O rescale in TMEM: multiply existing O by acc_scale = exp2(old_max - new_max)
# Only for kt > 0 (first tile: no existing O to rescale)
if kt > 0:
n_corr = HEAD_DIM // corr_tile_size
for ci in range(n_corr):
tTMrO_rs = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.acc_dtype)
tTMEM_LOAD_OtO_ci = cute.make_tensor(
tTMEM_LOAD_OtO.iterator + ci * corr_tile_size, tTMEM_LOAD_OtO.layout
)
tTMEM_STORE_OtO_ci = cute.make_tensor(
tTMEM_STORE_OtO.iterator + ci * corr_tile_size, tTMEM_STORE_OtO.layout
)
cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_ci, tTMrO_rs)
for j in cutlass.range(cute.size(tTMrO_rs), vectorize=True):
tTMrO_rs[j] = tTMrO_rs[j] * acc_scale
cute.copy(tiled_tmem_store_o, tTMrO_rs, tTMEM_STORE_OtO_ci)
# O rescale: DISABLED for NO-OP test
# if kt > 0:
# n_corr = HEAD_DIM // corr_tile_size
# for ci in range(n_corr):
# ...
# Pass 2: P = exp2(S * scale_log2 - row_max), accumulate row_sum
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)
rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout)