fix: remove double normalization in TMA epilogue (P already normalized before PV)

This commit is contained in:
2026-05-29 19:36:41 +00:00
parent fb971781aa
commit d0a50f1f2e

View File

@@ -255,10 +255,8 @@ fmha_6warp_tma_kernel(
asm volatile("tcgen05.wait::ld.sync.aligned;");
if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c];
}
-float row_sum = *sRowSum;
-float inv_rs = 1.0f / row_sum;
-if (lane == 0) for (int d=0;d<HD;d++) o[d] = f32_to_bf16(o_vals[d] * inv_rs);
-if (lane == 0 && lse) lse[0] = logf(row_sum) + *sRowMax;
+if (lane == 0) for (int d=0;d<HD;d++) o[d] = f32_to_bf16(o_vals[d]);
+if (lane == 0 && lse) lse[0] = logf(*sRowSum) + *sRowMax;
}
__syncthreads();