fix: use un-normalized P for multi-tile PV (correct online softmax merge)

This commit is contained in:
2026-05-29 19:57:54 +00:00
parent 43ae3e7f98
commit d47b2bfcce

View File

@@ -201,7 +201,7 @@ fmha_6warp_tma_multitile_kernel(FmhaTmaMultiTileParams params) {
s_vals[j] = expf(s_vals[j] - row_max);
row_sum += s_vals[j];
}
for (int j=0;j<kv_len;j++) s_vals[j] /= row_sum;
// DO NOT normalize P — use un-normalized exp(s-max) for correct multi-tile merge
for (int j=0;j<kv_len;j++) s_p_vals[j] = s_vals[j];
*sTileMax = row_max;
*sTileSum = row_sum;