Fix multi-row softmax: remove cross-lane wmax/wsum — each lane handles its own row independently
This commit is contained in:
@@ -225,15 +225,18 @@ fmha_6warp_multirow_kernel(FmhaMultiRowParams params) {
|
||||
}
|
||||
}
|
||||
|
||||
// Warp-level max reduction (only active lanes participate meaningfully)
|
||||
// Inactive lanes have row_max = -INFINITY, which is safe for fmax
|
||||
row_max = wmax(row_max);
|
||||
// Per-row max: each lane has its own row's values.
|
||||
// NO cross-lane reduction — each row's max is independent.
|
||||
// For T=1, only lane 0 has data; we still skip wmax because
|
||||
// the single row's max is already in lane 0.
|
||||
// However, for T=1 backward compat: lane 0 needs its value
|
||||
// written to sRowMax. No broadcast needed for multi-row.
|
||||
|
||||
if (lane < (unsigned)warp_num_rows) {
|
||||
sRowMax[warp_row_start + lane] = row_max;
|
||||
}
|
||||
|
||||
// Compute exp and sum per row
|
||||
// Compute exp and sum per row (independent per lane)
|
||||
float row_sum = 0.0f;
|
||||
if (lane < (unsigned)warp_num_rows) {
|
||||
for (int j = 0; j < SK_TILE; j++) {
|
||||
@@ -241,7 +244,7 @@ fmha_6warp_multirow_kernel(FmhaMultiRowParams params) {
|
||||
row_sum += s_vals[j];
|
||||
}
|
||||
}
|
||||
row_sum = wsum(row_sum);
|
||||
// No cross-lane sum reduction — each row's sum is in its lane
|
||||
|
||||
if (lane < (unsigned)warp_num_rows) {
|
||||
sRowSum[warp_row_start + lane] = row_sum;
|
||||
|
||||
Reference in New Issue
Block a user