fix: remove double normalization in fmha_6warp_multihead epilogue

P was already normalized in softmax step. PV = P_norm @ V gives the
correct attention output. Dividing by row_sum again in the epilogue
produces O = O_correct / row_sum (128x too small for uniform data).
This commit is contained in:
2026-05-30 08:26:20 +00:00
parent cfac224b59
commit 10915c4e70

View File

@@ -279,22 +279,19 @@ fmha_6warp_multihead_kernel(FmhaParams params) {
asm volatile("tcgen05.wait::ld.sync.aligned;");
if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c];
}
// Normalize: O_normalized = O_unnorm / row_sum
// LSE = log2(row_sum) + row_max * log2(e) — for multi-segment merge
// P was NORMALIZED in softmax step. PV = P @ V is already the normalized
// attention output. No further division by row_sum needed.
// For single-segment decode, write O directly.
// LSE is written for multi-segment merge (P5).
if (lane == 0) {
float inv_row_sum = 1.0f / row_sum;
for (int d = 0; d < HD; d++) {
o_head[d] = f32_to_bf16(o_vals[d] * inv_row_sum);
o_head[d] = f32_to_bf16(o_vals[d]);
}
// Write LSE if pointer is valid
if (lse_head) {
// LSE = log2(row_sum) + row_max / log(2)
// Since softmax was: exp(x - row_max) / row_sum
// log(softmax) = (x - row_max) - log(row_sum)
// LSE = log(row_sum) + row_max (natural log)
// Actually: the un-normalized output is sum(P*V) where P is the softmax weights
// row_sum is the denominator of softmax.
// LSE for the merge formula: lse = ln(row_sum) + row_max
// LSE = ln(row_sum) + row_max (natural log)
// This is the log of the softmax denominator, useful for
// multi-segment merge: O = sum(exp(lse_i) * O_i) / sum(exp(lse_i))
lse_head[0] = logf(row_sum) + row_max;
}
}