diff --git a/tests/unit/test_tma_konly.cu b/tests/unit/test_tma_konly.cu index 01dda2cb..04a9735f 100644 --- a/tests/unit/test_tma_konly.cu +++ b/tests/unit/test_tma_konly.cu @@ -86,11 +86,12 @@ fmha_tma_konly_kernel( } __syncthreads(); - // Load K sub-tile via TMA - if (lane == 0) { + // Load K sub-tile via TMA — only warp 0, lane 0 issues + if (wid == 0 && lane == 0) { tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)tma_k, mbar_addr, kt * MMA_K_BF16, 0); tma_mbarrier_arrive_expect_tx(mbar_addr, TILE_SZ * sizeof(bf16_t)); } + // ALL threads wait for TMA completion tma_mbarrier_wait(mbar_addr, phase); phase ^= 1; __syncthreads();