From 22c2cf35aa010a4fec0a5cd0b6da9317222187ac Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 10:22:44 +0000 Subject: [PATCH] fix: revert to scaled row_max, use exp2(old_max - new_max) for O rescaling row_max is in scaled domain (s_val * scale_log2). The O rescaling should be exp2(old_max - new_max) without extra scale_log2 because the max values already include the scaling factor. --- tests/unit/test_fmha_v3_stage_c_full.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_fmha_v3_stage_c_full.py b/tests/unit/test_fmha_v3_stage_c_full.py index b0bbda39..1c5b1fc5 100644 --- a/tests/unit/test_fmha_v3_stage_c_full.py +++ b/tests/unit/test_fmha_v3_stage_c_full.py @@ -273,7 +273,7 @@ class FmhaV3StageC: tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) for j in range(frg_cnt): for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): - row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j]) + row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2) row_max_safe = row_max if row_max == -cutlass.Float32.inf: row_max_safe = Float32(0.0) @@ -303,7 +303,7 @@ class FmhaV3StageC: # O rescale: if kt > 0, rescale O in TMEM by exp2((old_max - new_max) * scale_log2) if kt > 0: - corr_scale = cute.math.exp2(scale_log2 * (old_row_max - row_max_safe), fastmath=True) + corr_scale = cute.math.exp2(old_row_max - row_max_safe, fastmath=True) o_col_tiles = self.pv_mma_tiler[1] // corr_tile_size for i in range(o_col_tiles): tTMEM_LOAD_O_i = cute.make_tensor(tTMEM_LOAD_OtO.iterator + i * corr_tile_size, tTMEM_LOAD_OtO.layout)