diff --git a/tests/unit/test_fmha_smem_p.cu b/tests/unit/test_fmha_smem_p.cu index 001836d7..87368f76 100644 --- a/tests/unit/test_fmha_smem_p.cu +++ b/tests/unit/test_fmha_smem_p.cu @@ -190,6 +190,7 @@ int main() { cudaError_t launch_err = cudaGetLastError(); if (launch_err != cudaSuccess) { printf("LAUNCH ERROR: %s\n", cudaGetErrorString(launch_err)); return 1; } + cudaError_t err = cudaDeviceSynchronize(); if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; } cudaMemcpy(h_o, d_o, HD*sizeof(bf16_t), cudaMemcpyDeviceToHost);