From ba1e81f2dcaf44f021e8603e76575a6788afc92f Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 07:09:06 +0000 Subject: [PATCH] test: minimal TMEM isolation test (alloc, store, load, dealloc) --- tests/unit/test_tmem_minimal.cu | 119 ++++++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 tests/unit/test_tmem_minimal.cu diff --git a/tests/unit/test_tmem_minimal.cu b/tests/unit/test_tmem_minimal.cu new file mode 100644 index 00000000..ec7a9efe --- /dev/null +++ b/tests/unit/test_tmem_minimal.cu @@ -0,0 +1,119 @@ +/** + * Minimal TMEM test — isolate the hang. + * Test: alloc, store, load, dealloc. + * Run on B200 via: nvcc -std=c++20 -gencode=arch=compute_100a,code=sm_100a test_tmem_minimal.cu -o test_tmem -lcudart + */ +#include +#include +#include +#include + +// TMEM alloc: MUST be called by a FULLY ACTIVE WARP (all 32 lanes) +// num_columns: 32-512, power of 2 +// The alloc WRITES the tmem base pointer to the SMEM location pointed to by smem_ptr +__device__ void tmem_alloc(uint32_t smem_ptr, int num_cols) { + asm volatile("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;" + :: "r"(smem_ptr), "r"(num_cols)); +} + +__device__ void tmem_dealloc(uint32_t tmem_ptr, int num_cols) { + asm volatile("tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;" + :: "r"(tmem_ptr), "r"(num_cols)); +} + +__device__ void tmem_fence() { + asm volatile("tcgen05.fence.cta_group::1.sync.aligned;" ::: "memory"); +} + +__device__ void tmem_store(uint32_t col, uint32_t r0, uint32_t r1, uint32_t r2, uint32_t r3) { + asm volatile("tcgen05.st.sync.aligned.16x256b.x1.b32 [%0], {%1, %2, %3, %4};" + :: "r"(col), "r"(r0), "r"(r1), "r"(r2), "r"(r3)); +} + +__device__ void tmem_load(uint32_t col, uint32_t &r0, uint32_t &r1, uint32_t &r2, uint32_t &r3) { + asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4];" + : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3) : "r"(col)); +} + +__global__ void test_tmem_alloc_only() { + extern __shared__ char sbuf[]; + uint32_t* tmem_base_ptr = (uint32_t*)sbuf; + + int wid = threadIdx.x / 32; + if (wid == 0) { + uint32_t smem_ptr = __cvta_generic_to_shared(tmem_base_ptr); + tmem_alloc(smem_ptr, 32); + } + __syncthreads(); + + uint32_t tmem_base = *tmem_base_ptr; + if (threadIdx.x == 0) printf("alloc_only: tmem_base = %u\n", tmem_base); + + if (wid == 0) { + tmem_dealloc(tmem_base, 32); + } + if (threadIdx.x == 0) printf("alloc_only: PASS (no hang)\n"); +} + +__global__ void test_tmem_store_load() { + extern __shared__ char sbuf[]; + uint32_t* tmem_base_ptr = (uint32_t*)sbuf; + + int wid = threadIdx.x / 32; + if (wid == 0) { + uint32_t smem_ptr = __cvta_generic_to_shared(tmem_base_ptr); + tmem_alloc(smem_ptr, 32); + } + __syncthreads(); + + uint32_t tmem_base = *tmem_base_ptr; + if (threadIdx.x == 0) printf("store_load: tmem_base = %u\n", tmem_base); + + // Warp 0, lane 0: store 4 floats to TMEM column 0 + // tcgen05.st is a warp-collective op — all 16 lanes of a half-warp participate + // Lane 0 writes, others write 0 (or their own values) + if (wid == 0) { + float f0 = 1.0f, f1 = 2.0f, f2 = 3.0f, f3 = 4.0f; + uint32_t u0, u1, u2, u3; + memcpy(&u0, &f0, 4); memcpy(&u1, &f1, 4); + memcpy(&u2, &f2, 4); memcpy(&u3, &f3, 4); + // Column address = tmem_base + column_index + // Column 0, row group 0 (rows 0-15) + uint32_t col_addr = tmem_base + 0; + tmem_store(col_addr, u0, u1, u2, u3); + } + tmem_fence(); + __syncthreads(); + + // Read back + if (wid == 0) { + uint32_t u0, u1, u2, u3; + uint32_t col_addr = tmem_base + 0; + tmem_load(col_addr, u0, u1, u2, u3); + float f0, f1, f2, f3; + memcpy(&f0, &u0, 4); memcpy(&f1, &u1, 4); + memcpy(&f2, &u2, 4); memcpy(&f3, &u3, 4); + if (threadIdx.x == 0) printf("store_load: read %f %f %f %f\n", f0, f1, f2, f3); + } + __syncthreads(); + + if (wid == 0) tmem_dealloc(tmem_base, 32); + if (threadIdx.x == 0) printf("store_load: PASS\n"); +} + +int main() { + printf("=== TMET Minimal Test ===\n\n"); + + printf("Test 1: Alloc + Dealloc only...\n"); + test_tmem_alloc_only<<<1, 64, 256>>>(); + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { printf("FAIL: %s\n", cudaGetErrorString(err)); return 1; } + + printf("\nTest 2: Store + Load...\n"); + test_tmem_store_load<<<1, 64, 1024>>>(); + err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { printf("FAIL: %s\n", cudaGetErrorString(err)); return 1; } + + printf("\nAll tests PASSED!\n"); + return 0; +}