FIX: acc_scale was double-multiplying by scale_log2

row_max is already in log2(scaled) space (S * scale_log2), so
old_row_max - row_max_safe is the correct exponent for exp2.
The old code computed exp2(scale_log2 * (old_row_max - row_max_safe))
which is exp2(scale_log2^2 * (old_max_S - new_max_S)) — wrong.
This commit is contained in:
2026-05-22 16:42:45 +00:00
parent c9fe26a5fc
commit b50968dfaf

View File

@@ -316,8 +316,10 @@ class FmhaV3StageC:
row_max_safe = Float32(0.0)
# acc_scale for both row_sum rescale and O rescale.
acc_scale_ = scale_log2 * (old_row_max - row_max_safe)
acc_scale = cute.math.exp2(acc_scale_, fastmath=True)
# row_max is already in log2(scaled) space (S * scale_log2),
# so the difference old_row_max - row_max_safe is already the
# correct exponent for exp2. No extra scale_log2 factor.
acc_scale = cute.math.exp2(old_row_max - row_max_safe, fastmath=True)
if old_row_max == -cutlass.Float32.inf:
acc_scale = Float32(0.0)
row_sum *= acc_scale