diff --git a/dsv4/kernels/attention/fmha_6warp_multirow.cuh b/dsv4/kernels/attention/fmha_6warp_multirow.cuh index d4bf44a3..c5bdc17a 100644 --- a/dsv4/kernels/attention/fmha_6warp_multirow.cuh +++ b/dsv4/kernels/attention/fmha_6warp_multirow.cuh @@ -225,15 +225,18 @@ fmha_6warp_multirow_kernel(FmhaMultiRowParams params) { } } - // Warp-level max reduction (only active lanes participate meaningfully) - // Inactive lanes have row_max = -INFINITY, which is safe for fmax - row_max = wmax(row_max); + // Per-row max: each lane has its own row's values. + // NO cross-lane reduction — each row's max is independent. + // For T=1, only lane 0 has data; we still skip wmax because + // the single row's max is already in lane 0. + // However, for T=1 backward compat: lane 0 needs its value + // written to sRowMax. No broadcast needed for multi-row. if (lane < (unsigned)warp_num_rows) { sRowMax[warp_row_start + lane] = row_max; } - // Compute exp and sum per row + // Compute exp and sum per row (independent per lane) float row_sum = 0.0f; if (lane < (unsigned)warp_num_rows) { for (int j = 0; j < SK_TILE; j++) { @@ -241,7 +244,7 @@ fmha_6warp_multirow_kernel(FmhaMultiRowParams params) { row_sum += s_vals[j]; } } - row_sum = wsum(row_sum); + // No cross-lane sum reduction — each row's sum is in its lane if (lane < (unsigned)warp_num_rows) { sRowSum[warp_row_start + lane] = row_sum;