diff --git a/tests/unit/test_fmha_smem_p.cu b/tests/unit/test_fmha_smem_p.cu index 09d1ba87..49105607 100644 --- a/tests/unit/test_fmha_smem_p.cu +++ b/tests/unit/test_fmha_smem_p.cu @@ -36,6 +36,7 @@ test_fmha_smem_p(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k, float* __restrict__ o_scalar, float scale) { const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32; + if (tid == 0) printf("Kernel started, smem=%d, wid=%d\n", threadIdx.x, wid); extern __shared__ char sbuf[]; uint32_t* sTmemBase = (uint32_t*)sbuf;