From 58be79957ddd5e935ea09e93e7f57860a62b68f0 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 09:32:33 +0000 Subject: [PATCH] test: 32 TMEM cols, add MMA call with N=32, read S from TMEM --- tests/unit/test_umma_qk.cu | 47 ++++++++++++++++++++++++++------------ 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/tests/unit/test_umma_qk.cu b/tests/unit/test_umma_qk.cu index 0097f1ed..5af429da 100644 --- a/tests/unit/test_umma_qk.cu +++ b/tests/unit/test_umma_qk.cu @@ -66,7 +66,7 @@ test_umma_qk_hd16( // ================================================================ if (wid == 0) { uint32_t smem_ptr = __cvta_generic_to_shared(sTmemBase); - tmem_alloc(smem_ptr, 128); + tmem_alloc(smem_ptr, 32); // 32 columns like minimal test } __syncthreads(); uint32_t tmem_base = *sTmemBase; @@ -78,14 +78,14 @@ test_umma_qk_hd16( } __syncthreads(); - // Zero TMEM — test 2 stores with syncwarp between them - if (wid == 0) { - tmem_store(tmem_base, 0, 0, 0, 0); - __syncwarp(); - tmem_store(tmem_base + 1, 0, 0, 0, 0); - __syncwarp(); - } - __syncthreads(); + // Zero TMEM — skip (minimal test didn't zero) + // if (wid == 0) { + // tmem_store(tmem_base, 0, 0, 0, 0); + // __syncwarp(); + // tmem_store(tmem_base + 1, 0, 0, 0, 0); + // __syncwarp(); + // } + // __syncthreads(); if (tid == 0) { s_out[140] = 2.0f; // sentinel: survived tmem_store @@ -136,13 +136,30 @@ test_umma_qk_hd16( // __syncthreads(); // ================================================================ - // Read S from TMEM — skipped (no TMEM allocated) + // Call tcgen05.mma SS — test with 32 columns (N=32) + // ================================================================ + // Rebuild idesc for N=32 + uint32_t idesc = make_idesc(128, 32); + + if (tid == 0) { + umma_ss_f16(tmem_base, desc_q, desc_k, idesc, /*accumulate=*/false); + } + __syncwarp(); + if (wid == 0 && lane == 0) { + tmem_fence_store(); + } + __syncthreads(); + + // ================================================================ + // Read S from TMEM // ================================================================ - // Just compute scalar reference and output descriptor debug info if (wid == 0) { - // Placeholder: write zeros - for (int col = 0; col < 128; col++) { - if (lane == 0) s_out[col] = 0.0f; + for (int col = 0; col < 32; col++) { + uint32_t u0, u1, u2, u3; + tmem_load(tmem_base + col, u0, u1, u2, u3); + if (lane == 0) { + s_out[col] = u32_to_f32(u0); // S[0, col] + } } } __syncthreads(); @@ -163,7 +180,7 @@ test_umma_qk_hd16( // TMEM dealloc if (wid == 0) { - tmem_dealloc(tmem_base, 128); + tmem_dealloc(tmem_base, 32); } }