Fix nvcc goto-bypasses-init errors in multi-head test

This commit is contained in:
2026-05-28 19:33:04 +00:00
parent aa41cfa2e5
commit 3fd302e7a0

View File

@@ -116,19 +116,18 @@ static int test_mha(int n_h) {
cudaFuncSetAttribute(fmha_6warp_multihead_kernel<HD>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
}
// Copy params to device (constant mem is simpler but let's use uniform for now)
// Actually, the kernel takes FmhaParams by value, so we pass it directly
fmha_6warp_multihead_kernel<HD><<<dim3(1, n_h, 1), 192, smem>>>(params);
cudaError_t launch_err = cudaGetLastError();
cudaError_t sync_err = cudaSuccess;
if (launch_err != cudaSuccess) {
printf("LAUNCH ERROR: %s\n", cudaGetErrorString(launch_err));
pass = 0; goto cleanup;
}
cudaError_t err = cudaDeviceSynchronize();
if (err != cudaSuccess) {
printf("CUDA ERROR: %s\n", cudaGetErrorString(err));
sync_err = cudaDeviceSynchronize();
if (sync_err != cudaSuccess) {
printf("CUDA ERROR: %s\n", cudaGetErrorString(sync_err));
pass = 0; goto cleanup;
}
@@ -238,13 +237,14 @@ static int test_mqa(int n_q, int n_kv) {
fmha_6warp_multihead_kernel<HD><<<dim3(1, n_q, 1), 192, smem>>>(params);
cudaError_t launch_err = cudaGetLastError();
cudaError_t sync_err2 = cudaSuccess;
if (launch_err != cudaSuccess) {
printf("LAUNCH ERROR: %s\n", cudaGetErrorString(launch_err));
pass = 0; goto cleanup;
}
cudaError_t err = cudaDeviceSynchronize();
if (err != cudaSuccess) {
printf("CUDA ERROR: %s\n", cudaGetErrorString(err));
sync_err2 = cudaDeviceSynchronize();
if (sync_err2 != cudaSuccess) {
printf("CUDA ERROR: %s\n", cudaGetErrorString(sync_err2));
pass = 0; goto cleanup;
}
@@ -331,20 +331,21 @@ static int test_batched(int n_h, int batch_size) {
fmha_6warp_multihead_kernel<HD><<<dim3(1, n_h, batch_size), 192, smem>>>(params);
cudaError_t launch_err = cudaGetLastError();
cudaError_t sync_err3 = cudaSuccess;
int checked = 0, failed = 0;
if (launch_err != cudaSuccess) {
printf("LAUNCH ERROR: %s\n", cudaGetErrorString(launch_err));
pass = 0; goto cleanup;
}
cudaError_t err = cudaDeviceSynchronize();
if (err != cudaSuccess) {
printf("CUDA ERROR: %s\n", cudaGetErrorString(err));
sync_err3 = cudaDeviceSynchronize();
if (sync_err3 != cudaSuccess) {
printf("CUDA ERROR: %s\n", cudaGetErrorString(sync_err3));
pass = 0; goto cleanup;
}
cudaMemcpy(h_o, d_o, total_q * HD * sizeof(bf16_t), cudaMemcpyDeviceToHost);
// Verify a sample of heads across batches
int checked = 0, failed = 0;
for (int b = 0; b < batch_size; b++) {
for (int h = 0; h < n_h; h++) {
int idx = b * n_h + h;