fix: remove double normalization in TMA epilogue (P already normalized before PV)
This commit is contained in:
@@ -255,10 +255,8 @@ fmha_6warp_tma_kernel(
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c];
|
||||
}
|
||||
- float row_sum = *sRowSum;
|
||||
- float inv_rs = 1.0f / row_sum;
|
||||
- if (lane == 0) for (int d=0;d<HD;d++) o[d] = f32_to_bf16(o_vals[d] * inv_rs);
|
||||
- if (lane == 0 && lse) lse[0] = logf(row_sum) + *sRowMax;
|
||||
+ if (lane == 0) for (int d=0;d<HD;d++) o[d] = f32_to_bf16(o_vals[d]);
|
||||
+ if (lane == 0 && lse) lse[0] = logf(*sRowSum) + *sRowMax;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user