test: use accumulate=false for first K-tile, skip TMEM zero

This commit is contained in:
2026-05-28 12:50:44 +00:00
parent e8ac2120ad
commit 435ca037cf

View File

@@ -62,13 +62,8 @@ test_umma_hd64(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
__syncthreads();
uint32_t tb = *sTmemBase;
// Zero TMEM (tcgen05.alloc does NOT zero)
if (wid == 0) {
for (int col = 0; col < 128; col++) {
tmem_store(tb + col, 0, 0, 0, 0);
}
tmem_fence_store();
}
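// Note: tcgen05.alloc does NOT zero TMEM, so the accumulator initially holds stale data.
// The first K-tile MMA is issued with accumulate=false, so it does not add to that
// stale data; later K-tiles use accumulate=true. No explicit zeroing pass is needed.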
__syncthreads();
// Multi-K-tile QK GEMM
@@ -94,7 +89,7 @@ test_umma_hd64(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
uint64_t dk = make_umma_desc_kmajor_none(k_addr, BLOCK_MN);
if (tid == 0) {
umma_ss_f16(tb, dq, dk, idesc, true); // accumulate across K-tiles
umma_ss_f16(tb, dq, dk, idesc, kt > 0); // first tile: no accumulate
}
// Fence after each K-tile MMA to ensure TMEM is updated
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");