diff --git a/dsv4/kernels/attention/fmha_6warp.cuh b/dsv4/kernels/attention/fmha_6warp.cuh index bb9c86c8..7740d84b 100644 --- a/dsv4/kernels/attention/fmha_6warp.cuh +++ b/dsv4/kernels/attention/fmha_6warp.cuh @@ -211,7 +211,6 @@ fmha_6warp_kernel( // Epilogue: TMEM → regs → normalize → BF16 → GMEM (warp 0) // ================================================================ if (wid == 0) { - float inv_sum = 1.0f / *sRowSum; float o_vals[HD]; for (int n = 0; n < HD / 8; n++) { float tmp[8]; @@ -220,7 +219,7 @@ fmha_6warp_kernel( "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) : "r"(tb + n*8)); asm volatile("tcgen05.wait::ld.sync.aligned;"); - if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c] * inv_sum; + if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c]; } if (lane == 0) for (int d=0;d