diff --git a/tests/unit/test_umma_qk_hd64.cu b/tests/unit/test_umma_qk_hd64.cu index 580a9da7..c4ba6ed1 100644 --- a/tests/unit/test_umma_qk_hd64.cu +++ b/tests/unit/test_umma_qk_hd64.cu @@ -67,14 +67,10 @@ test_umma_hd64(const bf16_t* q, const bf16_t* k, uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK), 128); uint32_t idesc = make_idesc(128, 128); - // MMA + // MMA — always accumulate (TMEM starts at 0 after alloc) if (lane == 0) { - umma_ss_f16(tb, dq, dk, idesc, kt > 0); + umma_ss_f16(tb, dq, dk, idesc, true); // Always accumulate } - __syncwarp(); // Ensure MMA is issued - asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); - __syncthreads(); // Wait for all warps - __syncthreads(); // Extra barrier for safety } // Read TMEM