debug: add prints to multirow multitile test
This commit is contained in:
@@ -149,11 +149,17 @@ static int test_single(int T, int s_k, int n_h = 1, int batch = 1) {
|
||||
dim3 grid(1, n_h, batch);
|
||||
fmha_6warp_tma_multirow_multitile_kernel<HD><<<grid, 192, smem>>>(params);
|
||||
|
||||
cudaError_t lerr = cudaGetLastError();
|
||||
if (lerr != cudaSuccess) {
|
||||
printf(" LAUNCH ERROR: %s\n", cudaGetErrorString(lerr));
|
||||
return 1;
|
||||
}
|
||||
cudaError_t err = cudaDeviceSynchronize();
|
||||
if (err != cudaSuccess) {
|
||||
printf(" CUDA ERROR: %s\n", cudaGetErrorString(err));
|
||||
return 1;
|
||||
}
|
||||
printf(" Kernel completed OK.\n");
|
||||
|
||||
cudaMemcpy(h_o, d_o, total_heads * MAX_T * HD * sizeof(bf16_t), cudaMemcpyDeviceToHost);
|
||||
cudaMemcpy(h_lse, d_lse, total_heads * MAX_T * sizeof(float), cudaMemcpyDeviceToHost);
|
||||
@@ -189,6 +195,8 @@ static int test_single(int T, int s_k, int n_h = 1, int batch = 1) {
|
||||
}
|
||||
|
||||
int main() {
|
||||
printf("START: test_fmha_6warp_tma_multirow_multitile HD=%d\n", HD);
|
||||
fflush(stdout);
|
||||
int total_fail = 0;
|
||||
|
||||
printf("\n=== 6-warp TMA FMHA multi-row multi-tile HD=%d ===\n", HD);
|
||||
|
||||
Reference in New Issue
Block a user