6-warp specialized FMHA kernel — ALL HD=16/64/128/256 PASS cos 0.999997+

Warp layout (192 threads):
- Warps 0-3: Softmax + correction + epilogue
- Warp 4: MMA (QK + PV GEMM)
- Warp 5: Data staging (Q/K/V loads, direct GMEM for now)
CTA-wide __syncthreads() sync between phases.

Fix: removed spurious inv_sum normalization in epilogue
(MMA output is already correctly scaled with softmax'd P).

Files: fmha_6warp.cuh + test_fmha_6warp*.cu
This commit is contained in:
2026-05-28 16:30:55 +00:00
parent 2a6d72912a
commit b3020c2811

View File

@@ -211,7 +211,6 @@ fmha_6warp_kernel(
// Epilogue: TMEM → regs → normalize → BF16 → GMEM (warp 0)
// ================================================================
if (wid == 0) {
float inv_sum = 1.0f / *sRowSum;
float o_vals[HD];
for (int n = 0; n < HD / 8; n++) {
float tmp[8];
@@ -220,7 +219,7 @@ fmha_6warp_kernel(
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
: "r"(tb + n*8));
asm volatile("tcgen05.wait::ld.sync.aligned;");
if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c] * inv_sum;
if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c];
}
if (lane == 0) for (int d=0;d<HD;d++) o[d] = f32_to_bf16(o_vals[d]);
}