From b50968dfafcf0e675e81ae35bcbc8e334846af38 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 16:42:45 +0000 Subject: [PATCH] FIX: acc_scale was double-multiplying by scale_log2 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit row_max is already in log2(scaled) space (S * scale_log2), so old_row_max - row_max_safe is the correct exponent for exp2. The old code computed exp2(scale_log2 * (old_row_max - row_max_safe)) which is exp2(scale_log2^2 * (old_max_S - new_max_S)) — wrong. --- tests/unit/test_fmha_v3_stage_c_full.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_fmha_v3_stage_c_full.py b/tests/unit/test_fmha_v3_stage_c_full.py index 94b7a460..2dea592c 100644 --- a/tests/unit/test_fmha_v3_stage_c_full.py +++ b/tests/unit/test_fmha_v3_stage_c_full.py @@ -316,8 +316,10 @@ class FmhaV3StageC: row_max_safe = Float32(0.0) # acc_scale for both row_sum rescale and O rescale. - acc_scale_ = scale_log2 * (old_row_max - row_max_safe) - acc_scale = cute.math.exp2(acc_scale_, fastmath=True) + # row_max is already in log2(scaled) space (S * scale_log2), + # so the difference old_row_max - row_max_safe is already the + # correct exponent for exp2. No extra scale_log2 factor. + acc_scale = cute.math.exp2(old_row_max - row_max_safe, fastmath=True) if old_row_max == -cutlass.Float32.inf: acc_scale = Float32(0.0) row_sum *= acc_scale