From e85d50dc3be830ee2f4cad73d53eaf07ec920c6f Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 10:21:50 +0000 Subject: [PATCH] fix: compute row_max from RAW S values, not scaled row_max should be the max of the raw QK scores, not pre-scaled. The scale_log2 is applied during exp2 and rescaling, not stored in row_max. This fixes the double-scaling bug that broke multi-tile O rescaling. --- tests/unit/test_fmha_v3_stage_c_full.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_fmha_v3_stage_c_full.py b/tests/unit/test_fmha_v3_stage_c_full.py index 921420dd..b0bbda39 100644 --- a/tests/unit/test_fmha_v3_stage_c_full.py +++ b/tests/unit/test_fmha_v3_stage_c_full.py @@ -273,7 +273,7 @@ class FmhaV3StageC: tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) for j in range(frg_cnt): for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): - row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2) + row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j]) row_max_safe = row_max if row_max == -cutlass.Float32.inf: row_max_safe = Float32(0.0)