diff --git a/tests/fmha_v3_real_softmax.py b/tests/fmha_v3_real_softmax.py index 0767596a..d8b235b5 100644 --- a/tests/fmha_v3_real_softmax.py +++ b/tests/fmha_v3_real_softmax.py @@ -276,22 +276,11 @@ class FmhaV3RealSoftmax: acc_scale = Float32(0.0) row_sum *= acc_scale - # O rescale in TMEM: multiply existing O by acc_scale = exp2(old_max - new_max) - # Only for kt > 0 (first tile: no existing O to rescale) - if kt > 0: - n_corr = HEAD_DIM // corr_tile_size - for ci in range(n_corr): - tTMrO_rs = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.acc_dtype) - tTMEM_LOAD_OtO_ci = cute.make_tensor( - tTMEM_LOAD_OtO.iterator + ci * corr_tile_size, tTMEM_LOAD_OtO.layout - ) - tTMEM_STORE_OtO_ci = cute.make_tensor( - tTMEM_STORE_OtO.iterator + ci * corr_tile_size, tTMEM_STORE_OtO.layout - ) - cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_ci, tTMrO_rs) - for j in cutlass.range(cute.size(tTMrO_rs), vectorize=True): - tTMrO_rs[j] = tTMrO_rs[j] * acc_scale - cute.copy(tiled_tmem_store_o, tTMrO_rs, tTMEM_STORE_OtO_ci) + # O rescale: DISABLED for NO-OP test + # if kt > 0: + # n_corr = HEAD_DIM // corr_tile_size + # for ci in range(n_corr): + # ... # Pass 2: P = exp2(S * scale_log2 - row_max), accumulate row_sum rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype) rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout)