From 5e3c61184cfbc55354ce006a5a1ebcbb8992c9fc Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 20:06:16 +0000 Subject: [PATCH] =?UTF-8?q?Fix=20multi-row=20softmax:=20remove=20cross-lan?= =?UTF-8?q?e=20wmax/wsum=20=E2=80=94=20each=20lane=20handles=20its=20own?= =?UTF-8?q?=20row=20independently?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dsv4/kernels/attention/fmha_6warp_multirow.cuh | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/dsv4/kernels/attention/fmha_6warp_multirow.cuh b/dsv4/kernels/attention/fmha_6warp_multirow.cuh index d4bf44a3..c5bdc17a 100644 --- a/dsv4/kernels/attention/fmha_6warp_multirow.cuh +++ b/dsv4/kernels/attention/fmha_6warp_multirow.cuh @@ -225,15 +225,18 @@ fmha_6warp_multirow_kernel(FmhaMultiRowParams params) { } } - // Warp-level max reduction (only active lanes participate meaningfully) - // Inactive lanes have row_max = -INFINITY, which is safe for fmax - row_max = wmax(row_max); + // Per-row max: each lane has its own row's values. + // NO cross-lane reduction — each row's max is independent. + // For T=1, only lane 0 has data; we still skip wmax because + // the single row's max is already in lane 0. + // However, for T=1 backward compat: lane 0 needs its value + // written to sRowMax. No broadcast needed for multi-row. if (lane < (unsigned)warp_num_rows) { sRowMax[warp_row_start + lane] = row_max; } - // Compute exp and sum per row + // Compute exp and sum per row (independent per lane) float row_sum = 0.0f; if (lane < (unsigned)warp_num_rows) { for (int j = 0; j < SK_TILE; j++) { @@ -241,7 +244,7 @@ fmha_6warp_multirow_kernel(FmhaMultiRowParams params) { row_sum += s_vals[j]; } } - row_sum = wsum(row_sum); + // No cross-lane sum reduction — each row's sum is in its lane if (lane < (unsigned)warp_num_rows) { sRowSum[warp_row_start + lane] = row_sum;