test: zero O TMEM before PV GEMM
This commit is contained in:
@@ -105,6 +105,16 @@ test_fmha_ts(const bf16_t* q, const bf16_t* k, const bf16_t* v,
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Zero O TMEM region
|
||||
if (wid == 0) {
|
||||
for (int n = 0; n < TMEM_O / 8; n++) {
|
||||
float z0=0,z1=0,z2=0,z3=0,z4=0,z5=0,z6=0,z7=0;
|
||||
asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0],{%1,%2,%3,%4,%5,%6,%7,%8};" :: "r"(tb_o+n*8),"f"(z0),"f"(z1),"f"(z2),"f"(z3),"f"(z4),"f"(z5),"f"(z6),"f"(z7));
|
||||
}
|
||||
tmem_fence_store();
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// STEP 3: PV GEMM via tcgen05.mma TS
|
||||
// A = P (TMEM, 128 rows × SK cols), B = V (SMEM K-tiles)
|
||||
// C = O (TMEM at tb_o, 128 rows × HD cols)
|
||||
|
||||
Reference in New Issue
Block a user