FMHA sink: don't double-scale sink bias
The sink bias from the checkpoint is already in the scaled domain (added to QK*scale in the reference softmax). The kernel's running_max is max(QK*scale), so the sink should be compared directly without multiplying by scale again.
This commit is contained in:
@@ -349,7 +349,10 @@ fmha_6warp_tma_multirow_multitile_kernel(FmhaTmaMultiRowMultiTileParams params)
|
||||
// Load per-head sink bias (same for all rows in this head)
|
||||
float sb = params.sink_bias[head_idx + batch_idx * params.n_h];
|
||||
if (my_row_active) {
|
||||
float sink_logit = sb * scale;
|
||||
// sink_bias is already in the scaled domain (added to QK*scale in softmax)
|
||||
// Do NOT multiply by scale again — the kernel's softmax already applies
|
||||
// scale to QK values, and running_max is in the scaled domain.
|
||||
float sink_logit = sb;
|
||||
float old_max = sRunningMax[my_row];
|
||||
float new_max = fmaxf(old_max, sink_logit);
|
||||
float rescale_old = (old_max > -INFINITY) ? expf(old_max - new_max) : 0.0f;
|
||||
|
||||
Reference in New Issue
Block a user