fix: compute row_max from RAW S values, not scaled

row_max should be the max of the raw QK scores, not pre-scaled.
The scale_log2 is applied during exp2 and rescaling, not stored in row_max.
This fixes the double-scaling bug that broke multi-tile O rescaling.
This commit is contained in:
2026-05-22 10:21:50 +00:00
parent 0bcb5aba2b
commit e85d50dc3b

View File

@@ -273,7 +273,7 @@ class FmhaV3StageC:
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
for j in range(frg_cnt):
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2)
row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j])
row_max_safe = row_max
if row_max == -cutlass.Float32.inf: row_max_safe = Float32(0.0)