fix: compute row_max from RAW S values, not scaled
row_max should be the max of the raw QK scores, not pre-scaled. The scale_log2 is applied during exp2 and rescaling, not stored in row_max. This fixes the double-scaling bug that broke multi-tile O rescaling.
This commit is contained in:
@@ -273,7 +273,7 @@ class FmhaV3StageC:
|
||||
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
|
||||
for j in range(frg_cnt):
|
||||
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
|
||||
row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2)
|
||||
row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j])
|
||||
|
||||
row_max_safe = row_max
|
||||
if row_max == -cutlass.Float32.inf: row_max_safe = Float32(0.0)
|
||||
|
||||
Reference in New Issue
Block a user