From ce88cd6e9e0ec42cf6401e75a1dfc2c549540d7d Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 14:31:22 +0000 Subject: [PATCH] Zero TMEM manually, all K-tiles accumulate=true --- tests/unit/test_pv_ss_128.cu | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_pv_ss_128.cu b/tests/unit/test_pv_ss_128.cu index 8bc4fc19..c7ed1566 100644 --- a/tests/unit/test_pv_ss_128.cu +++ b/tests/unit/test_pv_ss_128.cu @@ -57,16 +57,25 @@ test_pv_ss_128() __syncthreads(); uint32_t tb = *sTmemBase; - // PV SS MMA: 8 K-tiles with accumulation + // Zero TMEM first (instead of relying on accumulate=false) + if (wid == 0) { + for (int n = 0; n < 2; n++) { // 16 cols / 8 = 2 iterations to zero O region + float z0=0,z1=0,z2=0,z3=0,z4=0,z5=0,z6=0,z7=0; + asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0],{%1,%2,%3,%4,%5,%6,%7,%8};" :: "r"(tb+n*8),"f"(z0),"f"(z1),"f"(z2),"f"(z3),"f"(z4),"f"(z5),"f"(z6),"f"(z7)); + } + tmem_fence_store(); + } + __syncthreads(); + + // PV SS MMA: 8 K-tiles with accumulation (all accumulate=true) // K-tile kt of (128,128): g_k=[2*kt, 2*kt+1], offset = kt * 2048 BF16 - // P row 0 = 0.5 for first 16 positions, 0.3 for positions 16-127 { uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16); uint32_t idesc = make_idesc(BLOCK_MN, HD); for (int kt = 0; kt < 8; kt++) { bf16_t* sp = sP + kt * 2048; uint64_t dp = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sp), BLOCK_MN); - if (tid == 0) umma_ss_f16(tb, dp, dv, idesc, kt > 0); + if (tid == 0) umma_ss_f16(tb, dp, dv, idesc, true); // ALL accumulate asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); __syncthreads(); }