test: zero O TMEM before PV GEMM

This commit is contained in:
2026-05-28 13:40:30 +00:00
parent 2885b3f2ed
commit f24bc583dc

View File

@@ -105,6 +105,16 @@ test_fmha_ts(const bf16_t* q, const bf16_t* k, const bf16_t* v,
}
__syncthreads();
// Zero O TMEM region
// Writes zeros into the TMEM range starting at tb_o so that the PV GEMM
// below starts from a cleared O accumulator.
// Only threads with wid == 0 issue the stores (presumably wid is the warp
// index within the block -- confirm against its definition above).
// NOTE(review): the O tile is described below as 128 rows, but a .32x32b
// tcgen05.st is issued here by a single warp. Confirm against the PTX TMEM
// lane-access rules that this reaches every row of O and not only the
// lanes visible to the issuing warp.
if (wid == 0) {
// Each iteration stores 8 consecutive 32-bit values (.x8) at tb_o + n*8,
// so the loop covers TMEM_O / 8 groups of 8 (assumes TMEM_O is the O
// column count and a multiple of 8 -- TODO confirm).
for (int n = 0; n < TMEM_O / 8; n++) {
// Eight zero-valued source operands, one per stored 32-bit value.
float z0=0,z1=0,z2=0,z3=0,z4=0,z5=0,z6=0,z7=0;
asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0],{%1,%2,%3,%4,%5,%6,%7,%8};" :: "r"(tb_o+n*8),"f"(z0),"f"(z1),"f"(z2),"f"(z3),"f"(z4),"f"(z5),"f"(z6),"f"(z7));
}
// Helper defined outside this view; presumably waits for the tcgen05.st
// stores above to complete -- confirm against its definition.
tmem_fence_store();
}
// Block-wide barrier: no thread reaches the PV GEMM until the zeroing
// branch above has been passed by the threads that execute it.
__syncthreads();
// STEP 3: PV GEMM via tcgen05.mma TS
// A = P (TMEM, 128 rows × SK cols), B = V (SMEM K-tiles)
// C = O (TMEM at tb_o, 128 rows × HD cols)