From 10915c4e70604e4d60eabe6305c49824bcd0d554 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 30 May 2026 08:26:20 +0000 Subject: [PATCH] fix: remove double normalization in fmha_6warp_multihead epilogue P is already normalized in the softmax step, so PV = P_norm @ V is already the correct attention output. Dividing by row_sum again in the epilogue produces O = O_correct / row_sum (128x too small for uniform data, i.e. when row_sum = 128). --- .../attention/fmha_6warp_multihead.cuh | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/dsv4/kernels/attention/fmha_6warp_multihead.cuh b/dsv4/kernels/attention/fmha_6warp_multihead.cuh index 70034366..a8a8c87a 100644 --- a/dsv4/kernels/attention/fmha_6warp_multihead.cuh +++ b/dsv4/kernels/attention/fmha_6warp_multihead.cuh @@ -279,22 +279,19 @@ fmha_6warp_multihead_kernel(FmhaParams params) { asm volatile("tcgen05.wait::ld.sync.aligned;"); if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c]; } - // Normalize: O_normalized = O_unnorm / row_sum - // LSE = log2(row_sum) + row_max * log2(e) — for multi-segment merge + // P was NORMALIZED in softmax step. PV = P @ V is already the normalized + // attention output. No further division by row_sum needed. + // For single-segment decode, write O directly. + // LSE is written for multi-segment merge (P5). if (lane == 0) { - float inv_row_sum = 1.0f / row_sum; for (int d = 0; d < HD; d++) { - o_head[d] = f32_to_bf16(o_vals[d] * inv_row_sum); + o_head[d] = f32_to_bf16(o_vals[d]); } // Write LSE if pointer is valid if (lse_head) { - // LSE = log2(row_sum) + row_max / log(2) - // Since softmax was: exp(x - row_max) / row_sum - // log(softmax) = (x - row_max) - log(row_sum) - // LSE = log(row_sum) + row_max (natural log) - // Actually: the un-normalized output is sum(P*V) where P is the softmax weights - // row_sum is the denominator of softmax. 
- // LSE for the merge formula: lse = ln(row_sum) + row_max + // LSE = ln(row_sum) + row_max (natural log) + // This is the log of the softmax denominator, useful for + // multi-segment merge: O = sum(exp(lse_i) * O_i) / sum(exp(lse_i)) lse_head[0] = logf(row_sum) + row_max; } }