From 5795589abcf24e033689e898ccac594d72f0f9ce Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 09:48:27 +0000 Subject: [PATCH] test: TMEM 4 columns, individual store calls + loop load --- tests/unit/test_tmem_cols.cu | 124 +++++++++++++---------------------- 1 file changed, 47 insertions(+), 77 deletions(-) diff --git a/tests/unit/test_tmem_cols.cu b/tests/unit/test_tmem_cols.cu index 6eeb48f7..3ac39f41 100644 --- a/tests/unit/test_tmem_cols.cu +++ b/tests/unit/test_tmem_cols.cu @@ -1,5 +1,5 @@ /** - * TMEM column addressing test — is tmem_base + 1 a valid column? + * TMEM column addressing test — simplified, matches minimal test pattern */ #include @@ -7,107 +7,77 @@ #include #include -typedef unsigned short bf16_t; constexpr int WARP = 32; -__device__ void tmem_alloc(uint32_t smem_ptr, int num_cols) { - asm volatile("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;" - :: "r"(smem_ptr), "r"(num_cols)); +__device__ void tmem_alloc(uint32_t sp, int n) { + asm volatile("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;" :: "r"(sp), "r"(n)); } -__device__ void tmem_dealloc(uint32_t tmem_ptr, int num_cols) { - asm volatile("tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;" - :: "r"(tmem_ptr), "r"(num_cols)); +__device__ void tmem_dealloc(uint32_t tp, int n) { + asm volatile("tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;" :: "r"(tp), "r"(n)); } -__device__ void tmem_store(uint32_t col_addr, uint32_t r0, uint32_t r1, uint32_t r2, uint32_t r3) { - asm volatile("tcgen05.st.sync.aligned.16x256b.x1.b32 [%0], {%1, %2, %3, %4};" - :: "r"(col_addr), "r"(r0), "r"(r1), "r"(r2), "r"(r3)); +__device__ void tmem_store(uint32_t c, uint32_t r0, uint32_t r1, uint32_t r2, uint32_t r3) { + asm volatile("tcgen05.st.sync.aligned.16x256b.x1.b32 [%0], {%1, %2, %3, %4};" :: "r"(c), "r"(r0), "r"(r1), "r"(r2), "r"(r3)); } -__device__ void tmem_load(uint32_t col_addr, uint32_t& r0, uint32_t& r1, uint32_t& r2, uint32_t& r3) { - asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4];" - : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3) : "r"(col_addr)); -} -__device__ void tmem_fence_store() { - asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory"); +__device__ void tmem_load(uint32_t c, uint32_t &r0, uint32_t &r1, uint32_t &r2, uint32_t &r3) { + asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4];" : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3) : "r"(c)); } -__global__ void test_tmem_cols(float* out) { - int tid = threadIdx.x; - int lane = tid % WARP; - +__global__ void test_tmem_loop(float* out) { extern __shared__ char sbuf[]; uint32_t* sBase = (uint32_t*)sbuf; + int lane = threadIdx.x % WARP; - // Alloc 128 TMEM columns - if (tid < 32) { - uint32_t sp = __cvta_generic_to_shared(sBase); - tmem_alloc(sp, 128); + // Alloc 32 TMEM columns (same as minimal test) + if (threadIdx.x < 32) { + tmem_alloc(__cvta_generic_to_shared(sBase), 32); } __syncthreads(); - uint32_t tmem_base = *sBase; + uint32_t tb = *sBase; - // Store a unique value to each column - // Column i gets value i*1000 + lane*4 + 0..3 - if (tid < 32) { - for (int col = 0; col < 128; col++) { - float v0 = (float)(col * 1000 + lane * 4 + 0); - float v1 = (float)(col * 1000 + lane * 4 + 1); - float v2 = (float)(col * 1000 + lane * 4 + 2); - float v3 = (float)(col * 1000 + lane * 4 + 3); - uint32_t u0, u1, u2, u3; - memcpy(&u0, &v0, 4); memcpy(&u1, &v1, 4); - memcpy(&u2, &v2, 4); memcpy(&u3, &v3, 4); - tmem_store(tmem_base + col, u0, u1, u2, u3); - } - tmem_fence_store(); + // Store to columns 0..3 individually (no loop) + if (threadIdx.x < 32) { + float v0 = (float)(lane * 4 + 0); + float v1 = (float)(lane * 4 + 1); + float v2 = (float)(lane * 4 + 2); + float v3 = (float)(lane * 4 + 3); + uint32_t u0, u1, u2, u3; + memcpy(&u0, &v0, 4); memcpy(&u1, &v1, 4); + memcpy(&u2, &v2, 4); memcpy(&u3, &v3, 4); + + tmem_store(tb + 0, u0, u1, u2, u3); + tmem_store(tb + 1, u0, u1, u2, u3); + tmem_store(tb + 2, u0, u1, u2, u3); + tmem_store(tb + 3, u0, u1, u2, u3); } __syncthreads(); - // Read back first 8 columns (lane 0 only) - if (tid < 32) { - for (int col = 0; col < 8; col++) { + // Read back + if (threadIdx.x < 32) { + for (int c = 0; c < 4; c++) { uint32_t u0, u1, u2, u3; - tmem_load(tmem_base + col, u0, u1, u2, u3); - if (lane == 0) { - float v0; memcpy(&v0, &u0, 4); - out[col] = v0; - } + tmem_load(tb + c, u0, u1, u2, u3); + float v0; memcpy(&v0, &u0, 4); + if (lane == 0) out[c] = v0; } } __syncthreads(); - // Dealloc - if (tid < 32) { - tmem_dealloc(tmem_base, 128); - } + if (threadIdx.x < 32) tmem_dealloc(tb, 32); } int main() { - printf("=== TMEM Column Addressing Test ===\n"); - float* h_out = (float*)calloc(8, sizeof(float)); - float* d_out; - cudaMalloc(&d_out, 8 * sizeof(float)); - cudaMemset(d_out, 0, 8 * sizeof(float)); + printf("=== TMEM Loop Test ===\n"); + float* h_out = (float*)calloc(4, sizeof(float)); + float* d_out; cudaMalloc(&d_out, 4 * sizeof(float)); + cudaMemset(d_out, 0, 4 * sizeof(float)); - test_tmem_cols<<<1, 32, 256>>>(d_out); + test_tmem_loop<<<1, 32, 256>>>(d_out); cudaError_t err = cudaDeviceSynchronize(); - if (err != cudaSuccess) { - printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); - return 1; - } + if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; } - cudaMemcpy(h_out, d_out, 8 * sizeof(float), cudaMemcpyDeviceToHost); - printf("Column values (expected 0, 1000, 2000, ...):\n"); - for (int i = 0; i < 8; i++) { - printf(" col %d: %.1f (expected %d.0)\n", i, h_out[i], i * 1000); - } - - int ok = 1; - for (int i = 0; i < 8; i++) { - if (fabsf(h_out[i] - i * 1000.0f) > 0.1f) ok = 0; - } - printf("Test %s\n", ok ? "PASSED" : "FAILED"); - - cudaFree(d_out); - free(h_out); - return ok ? 0 : 1; + cudaMemcpy(h_out, d_out, 4 * sizeof(float), cudaMemcpyDeviceToHost); + for (int i = 0; i < 4; i++) printf("col %d: %.1f\n", i, h_out[i]); + printf("Test %s\n", h_out[0] == 0.0f ? "PASSED" : "CHECK"); + cudaFree(d_out); free(h_out); + return 0; }