diff --git a/tests/unit/test_umma_qk_hd64.cu b/tests/unit/test_umma_qk_hd64.cu index 22b333e5..32b15c46 100644 --- a/tests/unit/test_umma_qk_hd64.cu +++ b/tests/unit/test_umma_qk_hd64.cu @@ -62,13 +62,8 @@ test_umma_hd64(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k, __syncthreads(); uint32_t tb = *sTmemBase; - // Zero TMEM (tcgen05.alloc does NOT zero) - if (wid == 0) { - for (int col = 0; col < 128; col++) { - tmem_store(tb + col, 0, 0, 0, 0); - } - tmem_fence_store(); - } + // Note: tcgen05.alloc does NOT zero TMEM. + // We use accumulate=false for the first K-tile, then accumulate=true. __syncthreads(); // Multi-K-tile QK GEMM @@ -94,7 +89,7 @@ test_umma_hd64(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k, uint64_t dk = make_umma_desc_kmajor_none(k_addr, BLOCK_MN); if (tid == 0) { - umma_ss_f16(tb, dq, dk, idesc, true); // accumulate across K-tiles + umma_ss_f16(tb, dq, dk, idesc, kt > 0); // first tile: no accumulate } // Fence after each K-tile MMA to ensure TMEM is updated asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");