diff --git a/tests/unit/test_fmha_6warp_multihead.cu b/tests/unit/test_fmha_6warp_multihead.cu index 9a21b3f7..a13de359 100644 --- a/tests/unit/test_fmha_6warp_multihead.cu +++ b/tests/unit/test_fmha_6warp_multihead.cu @@ -116,19 +116,18 @@ static int test_mha(int n_h) { cudaFuncSetAttribute(fmha_6warp_multihead_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); } - // Copy params to device (constant mem is simpler but let's use uniform for now) - // Actually, the kernel takes FmhaParams by value, so we pass it directly fmha_6warp_multihead_kernel<<>>(params); cudaError_t launch_err = cudaGetLastError(); + cudaError_t sync_err = cudaSuccess; if (launch_err != cudaSuccess) { printf("LAUNCH ERROR: %s\n", cudaGetErrorString(launch_err)); pass = 0; goto cleanup; } - cudaError_t err = cudaDeviceSynchronize(); - if (err != cudaSuccess) { - printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); + sync_err = cudaDeviceSynchronize(); + if (sync_err != cudaSuccess) { + printf("CUDA ERROR: %s\n", cudaGetErrorString(sync_err)); pass = 0; goto cleanup; } @@ -238,13 +237,14 @@ static int test_mqa(int n_q, int n_kv) { fmha_6warp_multihead_kernel<<>>(params); cudaError_t launch_err = cudaGetLastError(); + cudaError_t sync_err2 = cudaSuccess; if (launch_err != cudaSuccess) { printf("LAUNCH ERROR: %s\n", cudaGetErrorString(launch_err)); pass = 0; goto cleanup; } - cudaError_t err = cudaDeviceSynchronize(); - if (err != cudaSuccess) { - printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); + sync_err2 = cudaDeviceSynchronize(); + if (sync_err2 != cudaSuccess) { + printf("CUDA ERROR: %s\n", cudaGetErrorString(sync_err2)); pass = 0; goto cleanup; } @@ -331,20 +331,21 @@ static int test_batched(int n_h, int batch_size) { fmha_6warp_multihead_kernel<<>>(params); cudaError_t launch_err = cudaGetLastError(); + cudaError_t sync_err3 = cudaSuccess; + int checked = 0, failed = 0; if (launch_err != cudaSuccess) { printf("LAUNCH ERROR: %s\n", cudaGetErrorString(launch_err)); pass = 0; goto cleanup; } - cudaError_t err = cudaDeviceSynchronize(); - if (err != cudaSuccess) { - printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); + sync_err3 = cudaDeviceSynchronize(); + if (sync_err3 != cudaSuccess) { + printf("CUDA ERROR: %s\n", cudaGetErrorString(sync_err3)); pass = 0; goto cleanup; } cudaMemcpy(h_o, d_o, total_q * HD * sizeof(bf16_t), cudaMemcpyDeviceToHost); // Verify a sample of heads across batches - int checked = 0, failed = 0; for (int b = 0; b < batch_size; b++) { for (int h = 0; h < n_h; h++) { int idx = b * n_h + h;