fix: remove double normalization in fmha_6warp_multihead epilogue
P was already normalized in the softmax step, so PV = P_norm @ V is already the correct attention output. Dividing by row_sum again in the epilogue produced O = O_correct / row_sum (128x too small for uniform data).
This commit is contained in:
@@ -279,22 +279,19 @@ fmha_6warp_multihead_kernel(FmhaParams params) {
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c];
|
||||
}
|
||||
// Normalize: O_normalized = O_unnorm / row_sum
|
||||
// LSE = log2(row_sum) + row_max * log2(e) — for multi-segment merge
|
||||
// P was NORMALIZED in softmax step. PV = P @ V is already the normalized
|
||||
// attention output. No further division by row_sum needed.
|
||||
// For single-segment decode, write O directly.
|
||||
// LSE is written for multi-segment merge (P5).
|
||||
if (lane == 0) {
|
||||
float inv_row_sum = 1.0f / row_sum;
|
||||
for (int d = 0; d < HD; d++) {
|
||||
o_head[d] = f32_to_bf16(o_vals[d] * inv_row_sum);
|
||||
o_head[d] = f32_to_bf16(o_vals[d]);
|
||||
}
|
||||
// Write LSE if pointer is valid
|
||||
if (lse_head) {
|
||||
// LSE = log2(row_sum) + row_max / log(2)
|
||||
// Since softmax was: exp(x - row_max) / row_sum
|
||||
// log(softmax) = (x - row_max) - log(row_sum)
|
||||
// LSE = log(row_sum) + row_max (natural log)
|
||||
// Actually: the un-normalized output is sum(P*V) where P is the softmax weights
|
||||
// row_sum is the denominator of softmax.
|
||||
// LSE for the merge formula: lse = ln(row_sum) + row_max
|
||||
// LSE = ln(row_sum) + row_max (natural log)
|
||||
// This is the log of the softmax denominator, useful for
|
||||
// multi-segment merge: O = sum(exp(lse_i) * O_i) / sum(exp(lse_i))
|
||||
lse_head[0] = logf(row_sum) + row_max;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user