test: 32 TMEM cols, add MMA call with N=32, read S from TMEM

This commit is contained in:
2026-05-28 09:32:33 +00:00
parent 22fb861447
commit 58be79957d

View File

@@ -66,7 +66,7 @@ test_umma_qk_hd16(
// ================================================================
if (wid == 0) {
uint32_t smem_ptr = __cvta_generic_to_shared(sTmemBase);
tmem_alloc(smem_ptr, 128);
tmem_alloc(smem_ptr, 32); // 32 columns like minimal test
}
__syncthreads();
uint32_t tmem_base = *sTmemBase;
@@ -78,14 +78,14 @@ test_umma_qk_hd16(
}
__syncthreads();
// Zero TMEM — test 2 stores with syncwarp between them
if (wid == 0) {
tmem_store(tmem_base, 0, 0, 0, 0);
__syncwarp();
tmem_store(tmem_base + 1, 0, 0, 0, 0);
__syncwarp();
}
__syncthreads();
// Zero TMEM — skip (minimal test didn't zero)
// if (wid == 0) {
// tmem_store(tmem_base, 0, 0, 0, 0);
// __syncwarp();
// tmem_store(tmem_base + 1, 0, 0, 0, 0);
// __syncwarp();
// }
// __syncthreads();
if (tid == 0) {
s_out[140] = 2.0f; // sentinel: survived tmem_store
@@ -136,13 +136,30 @@ test_umma_qk_hd16(
// __syncthreads();
// ================================================================
// Read S from TMEM — skipped (no TMEM allocated)
// Call tcgen05.mma SS — test with 32 columns (N=32)
// ================================================================
// Rebuild idesc for N=32
uint32_t idesc = make_idesc(128, 32);
if (tid == 0) {
umma_ss_f16(tmem_base, desc_q, desc_k, idesc, /*accumulate=*/false);
}
__syncwarp();
if (wid == 0 && lane == 0) {
tmem_fence_store();
}
__syncthreads();
// ================================================================
// Read S from TMEM
// ================================================================
// Just compute scalar reference and output descriptor debug info
if (wid == 0) {
// Placeholder: write zeros
for (int col = 0; col < 128; col++) {
if (lane == 0) s_out[col] = 0.0f;
for (int col = 0; col < 32; col++) {
uint32_t u0, u1, u2, u3;
tmem_load(tmem_base + col, u0, u1, u2, u3);
if (lane == 0) {
s_out[col] = u32_to_f32(u0); // S[0, col]
}
}
}
__syncthreads();
@@ -163,7 +180,7 @@ test_umma_qk_hd16(
// TMEM dealloc
if (wid == 0) {
tmem_dealloc(tmem_base, 128);
tmem_dealloc(tmem_base, 32);
}
}