Zero TMEM manually, all K-tiles accumulate=true

This commit is contained in:
2026-05-28 14:31:22 +00:00
parent 727c509454
commit ce88cd6e9e

View File

@@ -57,16 +57,25 @@ test_pv_ss_128()
__syncthreads();
uint32_t tb = *sTmemBase;
// PV SS MMA: 8 K-tiles with accumulation
// Zero TMEM first (instead of relying on accumulate=false)
if (wid == 0) {
for (int n = 0; n < 2; n++) { // 16 cols / 8 = 2 iterations to zero O region
float z0=0,z1=0,z2=0,z3=0,z4=0,z5=0,z6=0,z7=0;
asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0],{%1,%2,%3,%4,%5,%6,%7,%8};" :: "r"(tb+n*8),"f"(z0),"f"(z1),"f"(z2),"f"(z3),"f"(z4),"f"(z5),"f"(z6),"f"(z7));
}
tmem_fence_store();
}
__syncthreads();
// PV SS MMA: 8 K-tiles with accumulation (all accumulate=true)
// K-tile kt of (128,128): g_k=[2*kt, 2*kt+1], offset = kt * 2048 BF16
// P row 0 = 0.5 for first 16 positions, 0.3 for positions 16-127
{
uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
uint32_t idesc = make_idesc(BLOCK_MN, HD);
for (int kt = 0; kt < 8; kt++) {
bf16_t* sp = sP + kt * 2048;
uint64_t dp = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sp), BLOCK_MN);
if (tid == 0) umma_ss_f16(tb, dp, dv, idesc, kt > 0);
if (tid == 0) umma_ss_f16(tb, dp, dv, idesc, true); // ALL accumulate
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
__syncthreads();
}