Fix nvcc "goto bypasses variable initialization" errors in the multi-head tests by declaring variables before the first `goto cleanup`
This commit is contained in:
@@ -116,19 +116,18 @@ static int test_mha(int n_h) {
|
||||
cudaFuncSetAttribute(fmha_6warp_multihead_kernel<HD>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
|
||||
}
|
||||
|
||||
// Copy params to device (constant mem is simpler but let's use uniform for now)
|
||||
// Actually, the kernel takes FmhaParams by value, so we pass it directly
|
||||
fmha_6warp_multihead_kernel<HD><<<dim3(1, n_h, 1), 192, smem>>>(params);
|
||||
|
||||
cudaError_t launch_err = cudaGetLastError();
|
||||
cudaError_t sync_err = cudaSuccess;
|
||||
if (launch_err != cudaSuccess) {
|
||||
printf("LAUNCH ERROR: %s\n", cudaGetErrorString(launch_err));
|
||||
pass = 0; goto cleanup;
|
||||
}
|
||||
|
||||
cudaError_t err = cudaDeviceSynchronize();
|
||||
if (err != cudaSuccess) {
|
||||
printf("CUDA ERROR: %s\n", cudaGetErrorString(err));
|
||||
sync_err = cudaDeviceSynchronize();
|
||||
if (sync_err != cudaSuccess) {
|
||||
printf("CUDA ERROR: %s\n", cudaGetErrorString(sync_err));
|
||||
pass = 0; goto cleanup;
|
||||
}
|
||||
|
||||
@@ -238,13 +237,14 @@ static int test_mqa(int n_q, int n_kv) {
|
||||
fmha_6warp_multihead_kernel<HD><<<dim3(1, n_q, 1), 192, smem>>>(params);
|
||||
|
||||
cudaError_t launch_err = cudaGetLastError();
|
||||
cudaError_t sync_err2 = cudaSuccess;
|
||||
if (launch_err != cudaSuccess) {
|
||||
printf("LAUNCH ERROR: %s\n", cudaGetErrorString(launch_err));
|
||||
pass = 0; goto cleanup;
|
||||
}
|
||||
cudaError_t err = cudaDeviceSynchronize();
|
||||
if (err != cudaSuccess) {
|
||||
printf("CUDA ERROR: %s\n", cudaGetErrorString(err));
|
||||
sync_err2 = cudaDeviceSynchronize();
|
||||
if (sync_err2 != cudaSuccess) {
|
||||
printf("CUDA ERROR: %s\n", cudaGetErrorString(sync_err2));
|
||||
pass = 0; goto cleanup;
|
||||
}
|
||||
|
||||
@@ -331,20 +331,21 @@ static int test_batched(int n_h, int batch_size) {
|
||||
fmha_6warp_multihead_kernel<HD><<<dim3(1, n_h, batch_size), 192, smem>>>(params);
|
||||
|
||||
cudaError_t launch_err = cudaGetLastError();
|
||||
cudaError_t sync_err3 = cudaSuccess;
|
||||
int checked = 0, failed = 0;
|
||||
if (launch_err != cudaSuccess) {
|
||||
printf("LAUNCH ERROR: %s\n", cudaGetErrorString(launch_err));
|
||||
pass = 0; goto cleanup;
|
||||
}
|
||||
cudaError_t err = cudaDeviceSynchronize();
|
||||
if (err != cudaSuccess) {
|
||||
printf("CUDA ERROR: %s\n", cudaGetErrorString(err));
|
||||
sync_err3 = cudaDeviceSynchronize();
|
||||
if (sync_err3 != cudaSuccess) {
|
||||
printf("CUDA ERROR: %s\n", cudaGetErrorString(sync_err3));
|
||||
pass = 0; goto cleanup;
|
||||
}
|
||||
|
||||
cudaMemcpy(h_o, d_o, total_q * HD * sizeof(bf16_t), cudaMemcpyDeviceToHost);
|
||||
|
||||
// Verify a sample of heads across batches
|
||||
int checked = 0, failed = 0;
|
||||
for (int b = 0; b < batch_size; b++) {
|
||||
for (int h = 0; h < n_h; h++) {
|
||||
int idx = b * n_h + h;
|
||||
|
||||
Reference in New Issue
Block a user