test: QK GEMM + PV GEMM combined test
This commit is contained in:
@@ -40,6 +40,32 @@ test_mma_ts_copy()
|
||||
uint32_t tb = *sTmemBase;
|
||||
uint32_t tb_o = tb + 32;
|
||||
|
||||
// QK GEMM first, then PV
|
||||
bf16_t* sQ = sV + 16 * 16;
|
||||
bf16_t* sK = sQ + 128 * 16 + 4096;
|
||||
for (int i = tid; i < 128 * 16; i += 128) { sQ[i] = 0; sK[i] = 0; }
|
||||
__syncthreads();
|
||||
for (int d = tid; d < 16; d += 128) { int ck=d/8,lc=d%8; sQ[ck*16*64+lc] = f32_to_bf16(1.0f); }
|
||||
for (int i = tid; i < 128*16; i += 128) {
|
||||
int r=i/16,c=i%16; int ck=c/8,lc=c%8,tmn=r/8,lr=r%8;
|
||||
sK[ck*16*64+tmn*64+lr*8+lc] = f32_to_bf16(1.0f);
|
||||
}
|
||||
__syncthreads();
|
||||
uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ), 128);
|
||||
uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK), 128);
|
||||
uint32_t iqk = make_idesc(128, 128);
|
||||
if (lane == 0) umma_ss_f16(tb, dq, dk, iqk, false);
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
__syncthreads();
|
||||
if (wid == 0) {
|
||||
float tmp[8];
|
||||
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) : "r"(tb));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
if (lane == 0) printf("S[0,0] after QK: %.2f (expect 16.0)\n", tmp[0]);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Now overwrite columns 0-15 with P = all 1.0 for PV
|
||||
// Write A = all 1.0 into TMEM columns 0-15
|
||||
if (wid == 0) {
|
||||
for (int n = 0; n < 16 / 8; n++) {
|
||||
|
||||
Reference in New Issue
Block a user