fix: P write with lane stride, use sRowSum
This commit is contained in:
@@ -158,9 +158,11 @@ test_qk_softmax_kernel(
|
||||
__syncthreads();
|
||||
|
||||
// Write P to GMEM
|
||||
// Write P to GMEM — only active rows write
|
||||
// Use threadIdx.x to avoid out-of-bounds
|
||||
if (my_row_active) {
|
||||
float inv_rs = 1.0f / my_row_sum;
|
||||
for (int j = 0; j < s_k; j++) {
|
||||
float inv_rs = 1.0f / sRowSum[my_row];
|
||||
for (int j = lane; j < s_k; j += 32) {
|
||||
out_p[my_row * s_k + j] = my_p_vals[j] * inv_rs;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user