test: use accumulate=false for first K-tile, skip TMEM zero
This commit is contained in:
@@ -62,13 +62,8 @@ test_umma_hd64(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
|
||||
__syncthreads();
|
||||
uint32_t tb = *sTmemBase;
|
||||
|
||||
// Zero TMEM (tcgen05.alloc does NOT zero)
|
||||
if (wid == 0) {
|
||||
for (int col = 0; col < 128; col++) {
|
||||
tmem_store(tb + col, 0, 0, 0, 0);
|
||||
}
|
||||
tmem_fence_store();
|
||||
}
|
||||
// Note: tcgen05.alloc does NOT zero TMEM.
|
||||
// We use accumulate=false for the first K-tile, then accumulate=true.
|
||||
__syncthreads();
|
||||
|
||||
// Multi-K-tile QK GEMM
|
||||
@@ -94,7 +89,7 @@ test_umma_hd64(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
|
||||
uint64_t dk = make_umma_desc_kmajor_none(k_addr, BLOCK_MN);
|
||||
|
||||
if (tid == 0) {
|
||||
umma_ss_f16(tb, dq, dk, idesc, true); // accumulate across K-tiles
|
||||
umma_ss_f16(tb, dq, dk, idesc, kt > 0); // first tile: no accumulate
|
||||
}
|
||||
// Fence after each K-tile MMA to ensure TMEM is updated
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
|
||||
Reference in New Issue
Block a user