diff --git a/tests/unit/test_fmha_v3_stage_c_full.py b/tests/unit/test_fmha_v3_stage_c_full.py index 94b7a460..2dea592c 100644 --- a/tests/unit/test_fmha_v3_stage_c_full.py +++ b/tests/unit/test_fmha_v3_stage_c_full.py @@ -316,8 +316,10 @@ class FmhaV3StageC: row_max_safe = Float32(0.0) # acc_scale for both row_sum rescale and O rescale. - acc_scale_ = scale_log2 * (old_row_max - row_max_safe) - acc_scale = cute.math.exp2(acc_scale_, fastmath=True) + # row_max is already in log2(scaled) space (S * scale_log2), + # so the difference old_row_max - row_max_safe is already the + # correct exponent for exp2. No extra scale_log2 factor. + acc_scale = cute.math.exp2(old_row_max - row_max_safe, fastmath=True) if old_row_max == -cutlass.Float32.inf: acc_scale = Float32(0.0) row_sum *= acc_scale