FMHA sink: don't double-scale sink bias

The sink bias from the checkpoint is already in the scaled domain
(added to QK*scale in the reference softmax). The kernel's
running_max is max(QK*scale), so the sink should be compared
directly without multiplying by scale again.
2026-05-31 23:12:20 +00:00
parent 4adee1207f
commit c5adbbfde6


@@ -349,7 +349,10 @@ fmha_6warp_tma_multirow_multitile_kernel(FmhaTmaMultiRowMultiTileParams params)
     // Load per-head sink bias (same for all rows in this head)
     float sb = params.sink_bias[head_idx + batch_idx * params.n_h];
     if (my_row_active) {
-        float sink_logit = sb * scale;
+        // sink_bias is already in the scaled domain (added to QK*scale in softmax)
+        // Do NOT multiply by scale again — the kernel's softmax already applies
+        // scale to QK values, and running_max is in the scaled domain.
+        float sink_logit = sb;
         float old_max = sRunningMax[my_row];
         float new_max = fmaxf(old_max, sink_logit);
         float rescale_old = (old_max > -INFINITY) ? expf(old_max - new_max) : 0.0f;
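
The effect of the change can be checked on the host with a small sketch. It is not the kernel's code: `SoftmaxState`, `fold_logit`, `sink_weight` and `reference_sink_weight` are hypothetical names, and the running-sum update is an assumed standard online-softmax step, since the hunk above ends before the kernel's own sum update. The sketch folds the scaled QK values and then the sink into a running max and sum, and compares the resulting sink weight against a direct softmax over `QK*scale` with the unscaled sink bias appended as one extra logit.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Running softmax statistics for one row (hypothetical stand-in for the
// kernel's per-row shared-memory state).
struct SoftmaxState {
    float running_max = -INFINITY;
    float running_sum = 0.0f;
};

// Fold one logit that is already in the scaled domain into the running state.
// Mirrors the old_max / new_max / rescale_old pattern shown in the diff.
void fold_logit(SoftmaxState& st, float logit) {
    float old_max = st.running_max;
    float new_max = std::fmax(old_max, logit);
    float rescale_old = (old_max > -INFINITY) ? std::exp(old_max - new_max) : 0.0f;
    st.running_sum = st.running_sum * rescale_old + std::exp(logit - new_max);
    st.running_max = new_max;
}

// Softmax weight that ends up on the sink after all QK values are folded in.
// double_scale = true reproduces the behaviour before this commit (sb * scale).
float sink_weight(const std::vector<float>& qk, float scale, float sink_bias,
                  bool double_scale) {
    SoftmaxState st;
    for (float x : qk) fold_logit(st, x * scale);
    float sink_logit = double_scale ? sink_bias * scale : sink_bias;
    fold_logit(st, sink_logit);
    assert(st.running_sum > 0.0f);
    return std::exp(sink_logit - st.running_max) / st.running_sum;
}

// Direct (non-streaming) reference: the sink bias joins QK*scale unscaled.
double reference_sink_weight(const std::vector<float>& qk, float scale,
                             float sink_bias) {
    double sink_term = std::exp(static_cast<double>(sink_bias));
    double denom = sink_term;
    for (float x : qk) denom += std::exp(static_cast<double>(x) * scale);
    return sink_term / denom;
}
```

With `qk = {8, 16, 4}`, `scale = 0.125` and `sink_bias = 2`, the un-doubled path matches the reference (a sink weight of roughly 0.39), while the double-scaled path gives roughly 0.10, so the sink is under-weighted whenever `scale < 1` and the bias is positive.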