From b3020c2811a3dd362c99636d52fdb562cfe9cf83 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 16:30:55 +0000 Subject: [PATCH] =?UTF-8?q?6-warp=20specialized=20FMHA=20kernel=20?= =?UTF-8?q?=E2=80=94=20ALL=20HD=3D16/64/128/256=20PASS=20cos=200.999997+?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Warp layout (192 threads): - Warps 0-3: Softmax + correction + epilogue - Warp 4: MMA (QK + PV GEMM) - Warp 5: Data staging (Q/K/V loads, direct GMEM for now) CTA-wide __syncthreads() sync between phases. Fix: removed spurious inv_sum normalization in epilogue (MMA output is already correctly scaled with softmax'd P). Files: fmha_6warp.cuh + test_fmha_6warp*.cu --- dsv4/kernels/attention/fmha_6warp.cuh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dsv4/kernels/attention/fmha_6warp.cuh b/dsv4/kernels/attention/fmha_6warp.cuh index bb9c86c8..7740d84b 100644 --- a/dsv4/kernels/attention/fmha_6warp.cuh +++ b/dsv4/kernels/attention/fmha_6warp.cuh @@ -211,7 +211,6 @@ fmha_6warp_kernel( // Epilogue: TMEM → regs → normalize → BF16 → GMEM (warp 0) // ================================================================ if (wid == 0) { - float inv_sum = 1.0f / *sRowSum; float o_vals[HD]; for (int n = 0; n < HD / 8; n++) { float tmp[8]; @@ -220,7 +219,7 @@ fmha_6warp_kernel( "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) : "r"(tb + n*8)); asm volatile("tcgen05.wait::ld.sync.aligned;"); - if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c] * inv_sum; + if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c]; } if (lane == 0) for (int d=0;d