Fix multi-row softmax: remove cross-lane wmax/wsum — each lane handles its own row independently

This commit is contained in:
2026-05-28 20:06:16 +00:00
parent bf4dfd131b
commit 5e3c61184c

View File

@@ -225,15 +225,18 @@ fmha_6warp_multirow_kernel(FmhaMultiRowParams params) {
}
}
// Warp-level max reduction (only active lanes participate meaningfully)
// Inactive lanes have row_max = -INFINITY, which is safe for fmax
row_max = wmax(row_max);
// Per-row max: each lane has its own row's values.
// NO cross-lane reduction — each row's max is independent.
// For T=1, only lane 0 has data; we still skip wmax because
// the single row's max is already in lane 0.
// However, for T=1 backward compat: lane 0 needs its value
// written to sRowMax. No broadcast needed for multi-row.
if (lane < (unsigned)warp_num_rows) {
sRowMax[warp_row_start + lane] = row_max;
}
// Compute exp and sum per row
// Compute exp and sum per row (independent per lane)
float row_sum = 0.0f;
if (lane < (unsigned)warp_num_rows) {
for (int j = 0; j < SK_TILE; j++) {
@@ -241,7 +244,7 @@ fmha_6warp_multirow_kernel(FmhaMultiRowParams params) {
row_sum += s_vals[j];
}
}
row_sum = wsum(row_sum);
// No cross-lane sum reduction — each row's sum is in its lane
if (lane < (unsigned)warp_num_rows) {
sRowSum[warp_row_start + lane] = row_sum;