debug: add prints to multirow multitile test

This commit is contained in:
2026-05-30 04:40:06 +00:00
parent dd3e0fdfc8
commit 0ad35f8be6

View File

@@ -149,11 +149,17 @@ static int test_single(int T, int s_k, int n_h = 1, int batch = 1) {
dim3 grid(1, n_h, batch);
fmha_6warp_tma_multirow_multitile_kernel<HD><<<grid, 192, smem>>>(params);
cudaError_t lerr = cudaGetLastError();
if (lerr != cudaSuccess) {
printf(" LAUNCH ERROR: %s\n", cudaGetErrorString(lerr));
return 1;
}
cudaError_t err = cudaDeviceSynchronize();
if (err != cudaSuccess) {
printf(" CUDA ERROR: %s\n", cudaGetErrorString(err));
return 1;
}
printf(" Kernel completed OK.\n");
cudaMemcpy(h_o, d_o, total_heads * MAX_T * HD * sizeof(bf16_t), cudaMemcpyDeviceToHost);
cudaMemcpy(h_lse, d_lse, total_heads * MAX_T * sizeof(float), cudaMemcpyDeviceToHost);
@@ -189,6 +195,8 @@ static int test_single(int T, int s_k, int n_h = 1, int batch = 1) {
}
int main() {
printf("START: test_fmha_6warp_tma_multirow_multitile HD=%d\n", HD);
fflush(stdout);
int total_fail = 0;
printf("\n=== 6-warp TMA FMHA multi-row multi-tile HD=%d ===\n", HD);