fix: use un-normalized P for multi-tile PV (correct online softmax merge)
This commit is contained in:
@@ -201,7 +201,7 @@ fmha_6warp_tma_multitile_kernel(FmhaTmaMultiTileParams params) {
|
||||
s_vals[j] = expf(s_vals[j] - row_max);
|
||||
row_sum += s_vals[j];
|
||||
}
|
||||
for (int j=0;j<kv_len;j++) s_vals[j] /= row_sum;
|
||||
// DO NOT normalize P — use un-normalized exp(s-max) for correct multi-tile merge
|
||||
for (int j=0;j<kv_len;j++) s_p_vals[j] = s_vals[j];
|
||||
*sTileMax = row_max;
|
||||
*sTileSum = row_sum;
|
||||
|
||||
Reference in New Issue
Block a user