FIX: acc_scale was double-multiplying by scale_log2
row_max is already in log2(scaled) space (S * scale_log2), so old_row_max - row_max_safe is the correct exponent for exp2. The old code computed exp2(scale_log2 * (old_row_max - row_max_safe)) which is exp2(scale_log2^2 * (old_max_S - new_max_S)) — wrong.
This commit is contained in:
@@ -316,8 +316,10 @@ class FmhaV3StageC:
|
||||
row_max_safe = Float32(0.0)
|
||||
|
||||
# acc_scale for both row_sum rescale and O rescale.
|
||||
acc_scale_ = scale_log2 * (old_row_max - row_max_safe)
|
||||
acc_scale = cute.math.exp2(acc_scale_, fastmath=True)
|
||||
# row_max is already in log2(scaled) space (S * scale_log2),
|
||||
# so the difference old_row_max - row_max_safe is already the
|
||||
# correct exponent for exp2. No extra scale_log2 factor.
|
||||
acc_scale = cute.math.exp2(old_row_max - row_max_safe, fastmath=True)
|
||||
if old_row_max == -cutlass.Float32.inf:
|
||||
acc_scale = Float32(0.0)
|
||||
row_sum *= acc_scale
|
||||
|
||||
Reference in New Issue
Block a user