diff --git a/tests/unit/test_mma_ts_copy.cu b/tests/unit/test_mma_ts_copy.cu index 57a33de9..6ce429ec 100644 --- a/tests/unit/test_mma_ts_copy.cu +++ b/tests/unit/test_mma_ts_copy.cu @@ -70,11 +70,36 @@ test_mma_ts_copy() if (lane == 0) for (int c=0;c<8;c++) c_vals[n*8+c] = tmp[c]; } if (lane == 0) { - printf("C[0,0..7]: "); + printf("C[0,0..7] after PV: "); for (int c=0;c<8;c++) printf("%.2f ", c_vals[c]); - printf("(expect 32.0)\n"); + printf("(expect 16.0)\n"); } } + __syncthreads(); + + // NOW do QK GEMM — does it crash after PV TS MMA? + bf16_t* sQ = sV + 16 * 16; + bf16_t* sK = sQ + 128 * 16 + 4096; + for (int i = tid; i < 128 * 16; i += 128) { sQ[i] = 0; sK[i] = 0; } + __syncthreads(); + for (int d = tid; d < 16; d += 128) { int ck=d/8,lc=d%8; sQ[ck*16*64+lc] = f32_to_bf16(1.0f); } + for (int i = tid; i < 128*16; i += 128) { + int r=i/16,c=i%16; int ck=c/8,lc=c%8,tmn=r/8,lr=r%8; + sK[ck*16*64+tmn*64+lr*8+lc] = f32_to_bf16(1.0f); + } + __syncthreads(); + uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ), 128); + uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK), 128); + uint32_t iqk = make_idesc(128, 128); + if (lane == 0) umma_ss_f16(tb, dq, dk, iqk, false); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + __syncthreads(); + if (wid == 0) { + float tmp[8]; + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) : "r"(tb)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + if (lane == 0) printf("S[0,0] after QK (post-PV): %.2f (expect 16.0)\n", tmp[0]); + } if (wid == 0) tmem_dealloc(tb, 64); }