From c5adbbfde6e0f6b804c8a82a01d391956bb64488 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 31 May 2026 23:12:20 +0000 Subject: [PATCH] FMHA sink: don't double-scale sink bias The sink bias from the checkpoint is already in the scaled domain (added to QK*scale in the reference softmax). The kernel's running_max is max(QK*scale), so the sink should be compared directly without multiplying by scale again. --- dsv4/kernels/attention/fmha_6warp_tma_multirow_multitile.cuh | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dsv4/kernels/attention/fmha_6warp_tma_multirow_multitile.cuh b/dsv4/kernels/attention/fmha_6warp_tma_multirow_multitile.cuh index 86e9f53b..e5fa794d 100644 --- a/dsv4/kernels/attention/fmha_6warp_tma_multirow_multitile.cuh +++ b/dsv4/kernels/attention/fmha_6warp_tma_multirow_multitile.cuh @@ -349,7 +349,10 @@ fmha_6warp_tma_multirow_multitile_kernel(FmhaTmaMultiRowMultiTileParams params) // Load per-head sink bias (same for all rows in this head) float sb = params.sink_bias[head_idx + batch_idx * params.n_h]; if (my_row_active) { - float sink_logit = sb * scale; + // sink_bias is already in the scaled domain (added to QK*scale in softmax) + // Do NOT multiply by scale again — the kernel's softmax already applies + // scale to QK values, and running_max is in the scaled domain. + float sink_logit = sb; float old_max = sRunningMax[my_row]; float new_max = fmaxf(old_max, sink_logit); float rescale_old = (old_max > -INFINITY) ? expf(old_max - new_max) : 0.0f;