diff --git a/dsv4/kernels/attention/fmha_6warp_tma_multirow_multitile.cuh b/dsv4/kernels/attention/fmha_6warp_tma_multirow_multitile.cuh index 86e9f53b..e5fa794d 100644 --- a/dsv4/kernels/attention/fmha_6warp_tma_multirow_multitile.cuh +++ b/dsv4/kernels/attention/fmha_6warp_tma_multirow_multitile.cuh @@ -349,7 +349,10 @@ fmha_6warp_tma_multirow_multitile_kernel(FmhaTmaMultiRowMultiTileParams params) // Load per-head sink bias (same for all rows in this head) float sb = params.sink_bias[head_idx + batch_idx * params.n_h]; if (my_row_active) { - float sink_logit = sb * scale; + // sink_bias is already in the scaled domain (added to QK*scale in softmax) + // Do NOT multiply by scale again — the kernel's softmax already applies + // scale to QK values, and running_max is in the scaled domain. + float sink_logit = sb; float old_max = sRunningMax[my_row]; float new_max = fmaxf(old_max, sink_logit); float rescale_old = (old_max > -INFINITY) ? expf(old_max - new_max) : 0.0f;