diff --git a/tests/unit/test_fmha_ts_hd16.cu b/tests/unit/test_fmha_ts_hd16.cu index 889d06be..b14edf79 100644 --- a/tests/unit/test_fmha_ts_hd16.cu +++ b/tests/unit/test_fmha_ts_hd16.cu @@ -105,6 +105,16 @@ test_fmha_ts(const bf16_t* q, const bf16_t* k, const bf16_t* v, } __syncthreads(); + // Zero O TMEM region + if (wid == 0) { + for (int n = 0; n < TMEM_O / 8; n++) { + float z0=0,z1=0,z2=0,z3=0,z4=0,z5=0,z6=0,z7=0; + asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0],{%1,%2,%3,%4,%5,%6,%7,%8};" :: "r"(tb_o+n*8),"f"(z0),"f"(z1),"f"(z2),"f"(z3),"f"(z4),"f"(z5),"f"(z6),"f"(z7)); + } + tmem_fence_store(); + } + __syncthreads(); + // STEP 3: PV GEMM via tcgen05.mma TS // A = P (TMEM, 128 rows × SK cols), B = V (SMEM K-tiles) // C = O (TMEM at tb_o, 128 rows × HD cols)