6-warp specialized FMHA kernel — ALL HD=16/64/128/256 PASS cos 0.999997+
Warp layout (192 threads): - Warps 0-3: Softmax + correction + epilogue - Warp 4: MMA (QK + PV GEMM) - Warp 5: Data staging (Q/K/V loads, direct GMEM for now) CTA-wide __syncthreads() sync between phases. Fix: removed spurious inv_sum normalization in epilogue (MMA output is already correctly scaled with softmax'd P). Files: fmha_6warp.cuh + test_fmha_6warp*.cu
This commit is contained in:
@@ -211,7 +211,6 @@ fmha_6warp_kernel(
|
||||
// Epilogue: TMEM → regs → normalize → BF16 → GMEM (warp 0)
|
||||
// ================================================================
|
||||
if (wid == 0) {
|
||||
float inv_sum = 1.0f / *sRowSum;
|
||||
float o_vals[HD];
|
||||
for (int n = 0; n < HD / 8; n++) {
|
||||
float tmp[8];
|
||||
@@ -220,7 +219,7 @@ fmha_6warp_kernel(
|
||||
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
|
||||
: "r"(tb + n*8));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c] * inv_sum;
|
||||
if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c];
|
||||
}
|
||||
if (lane == 0) for (int d=0;d<HD;d++) o[d] = f32_to_bf16(o_vals[d]);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user