fix: write P with a lane-strided loop and normalize by sRowSum[my_row]

This commit is contained in:
2026-05-29 19:11:19 +00:00
parent fd6a9b00ae
commit a9a87fe7b8

View File

@@ -158,9 +158,11 @@ test_qk_softmax_kernel(
__syncthreads();
// Write P to GMEM
// Write P to GMEM — only active rows write
// Use threadIdx.x to avoid out-of-bounds
if (my_row_active) {
float inv_rs = 1.0f / my_row_sum;
for (int j = 0; j < s_k; j++) {
float inv_rs = 1.0f / sRowSum[my_row];
for (int j = lane; j < s_k; j += 32) {
out_p[my_row * s_k + j] = my_p_vals[j] * inv_rs;
}
}