diff --git a/tests/unit/test_qk_softmax.cu b/tests/unit/test_qk_softmax.cu index bb0e51ad..8bd009fe 100644 --- a/tests/unit/test_qk_softmax.cu +++ b/tests/unit/test_qk_softmax.cu @@ -158,9 +158,11 @@ test_qk_softmax_kernel( __syncthreads(); // Write P to GMEM + // Write P to GMEM — only active rows write + // Use threadIdx.x to avoid out-of-bounds if (my_row_active) { - float inv_rs = 1.0f / my_row_sum; - for (int j = 0; j < s_k; j++) { + float inv_rs = 1.0f / sRowSum[my_row]; + for (int j = lane; j < s_k; j += 32) { out_p[my_row * s_k + j] = my_p_vals[j] * inv_rs; } }