diff --git a/tests/unit/test_umma_qk.cu b/tests/unit/test_umma_qk.cu index 1c59d32f..c39d4817 100644 --- a/tests/unit/test_umma_qk.cu +++ b/tests/unit/test_umma_qk.cu @@ -56,12 +56,14 @@ test_umma_qk_hd16(const bf16_t* q, const bf16_t* k, uint64_t desc_k = make_umma_desc_kmajor_none(sK_smem, 128); uint32_t idesc = make_idesc(128, 128); - // MMA - if (tid == 0) { + // MMA — every thread with lane == 0 and wid < 4 calls umma_ss_f16 (stated reason: Layout D requires 4 warps; NOTE(review): confirm this applies to issuing the MMA and not only to the TMEM read below) + // Leaders are picked by the lane == 0 test, not by elect_one_sync. NOTE(review): elect_one (the __ballot_sync result) is not read in this hunk; confirm it is needed. + int elect_one = __ballot_sync(0xFFFFFFFF, lane == 0); + if (lane == 0 && wid < 4) { umma_ss_f16(tb, desc_q, desc_k, idesc, false); } __syncwarp(); - if (wid == 0 && lane == 0) tmem_fence_store(); + if (wid < 4 && lane == 0) tmem_fence_store(); __syncthreads(); // Read from TMEM using Layout D: 32x32b.x8 format