From f0cb71da5c40f3b71a9c49c5e9b5f809eb340e5a Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 09:54:19 +0000 Subject: [PATCH] test: TMEM 2-col with fence+sync between stores, separate wid==0 blocks --- tests/unit/test_tmem_cols.cu | 65 ++++++++++++++++++++++-------------- 1 file changed, 40 insertions(+), 25 deletions(-) diff --git a/tests/unit/test_tmem_cols.cu b/tests/unit/test_tmem_cols.cu index 1b73d4fd..f9a446ca 100644 --- a/tests/unit/test_tmem_cols.cu +++ b/tests/unit/test_tmem_cols.cu @@ -1,5 +1,5 @@ /** - * TMEM column addressing test — simplified, matches minimal test pattern + * TMEM 2-column store test with fence between stores. */ #include @@ -21,21 +21,25 @@ __device__ void tmem_store(uint32_t c, uint32_t r0, uint32_t r1, uint32_t r2, ui __device__ void tmem_load(uint32_t c, uint32_t &r0, uint32_t &r1, uint32_t &r2, uint32_t &r3) { asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4];" : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3) : "r"(c)); } +__device__ void tmem_fence() { + asm volatile("tcgen05.wait::st.sync.aligned;" :::"memory"); +} -__global__ void test_tmem_loop(float* out) { +__global__ void test_tmem_2col(float* out) { extern __shared__ char sbuf[]; uint32_t* sBase = (uint32_t*)sbuf; int lane = threadIdx.x % WARP; + int wid = threadIdx.x / WARP; - // Alloc 32 TMEM columns (same as minimal test) - if (threadIdx.x < 32) { + // Alloc 32 TMEM columns + if (wid == 0) { tmem_alloc(__cvta_generic_to_shared(sBase), 32); } __syncthreads(); uint32_t tb = *sBase; - // Store to columns 0..3 individually (no loop) - if (threadIdx.x < 32) { + // Store to columns 0 and 1 with fence between + if (wid == 0) { float v0 = (float)(lane * 4 + 0); float v1 = (float)(lane * 4 + 1); float v2 = (float)(lane * 4 + 2); @@ -46,40 +50,51 @@ __global__ void test_tmem_loop(float* out) { tmem_store(tb + 0, u0, u1, u2, u3); } - asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory"); + tmem_fence(); __syncthreads(); - if (threadIdx.x < 32) { + if (wid == 0) { + float v0 = (float)(lane * 4 + 100); + float v1 = (float)(lane * 4 + 101); + float v2 = (float)(lane * 4 + 102); + float v3 = (float)(lane * 4 + 103); + uint32_t u0, u1, u2, u3; + memcpy(&u0, &v0, 4); memcpy(&u1, &v1, 4); + memcpy(&u2, &v2, 4); memcpy(&u3, &v3, 4); tmem_store(tb + 1, u0, u1, u2, u3); } - } - asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory"); + tmem_fence(); __syncthreads(); - // Read back — 1 column only - if (threadIdx.x < 32) { - uint32_t u0, u1, u2, u3; - tmem_load(tb + 0, u0, u1, u2, u3); - float v0; memcpy(&v0, &u0, 4); - if (lane == 0) out[0] = v0; + // Read back columns 0 and 1 + if (wid == 0) { + uint32_t r0, r1, r2, r3; + tmem_load(tb + 0, r0, r1, r2, r3); + float f0; memcpy(&f0, &r0, 4); + if (lane == 0) out[0] = f0; + + tmem_load(tb + 1, r0, r1, r2, r3); + float f1; memcpy(&f1, &r0, 4); + if (lane == 0) out[1] = f1; } __syncthreads(); - if (threadIdx.x < 32) tmem_dealloc(tb, 32); + if (wid == 0) tmem_dealloc(tb, 32); } int main() { - printf("=== TMEM Loop Test ===\n"); - float* h_out = (float*)calloc(4, sizeof(float)); - float* d_out; cudaMalloc(&d_out, 4 * sizeof(float)); - cudaMemset(d_out, 0, 4 * sizeof(float)); + printf("=== TMEM 2-Column Test ===\n"); + float* h_out = (float*)calloc(2, sizeof(float)); + float* d_out; cudaMalloc(&d_out, 2 * sizeof(float)); + cudaMemset(d_out, 0, 2 * sizeof(float)); - test_tmem_loop<<<1, 64, 1024>>>(d_out); + test_tmem_2col<<<1, 64, 1024>>>(d_out); cudaError_t err = cudaDeviceSynchronize(); if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; } - cudaMemcpy(h_out, d_out, 4 * sizeof(float), cudaMemcpyDeviceToHost); - for (int i = 0; i < 4; i++) printf("col %d: %.1f\n", i, h_out[i]); - printf("Test %s\n", h_out[0] == 0.0f ? "PASSED" : "CHECK"); + cudaMemcpy(h_out, d_out, 2 * sizeof(float), cudaMemcpyDeviceToHost); + printf("col 0: %.1f (expected 0.0)\n", h_out[0]); + printf("col 1: %.1f (expected 100.0)\n", h_out[1]); + printf("Test %s\n", (fabsf(h_out[0]) < 0.1f && fabsf(h_out[1] - 100.0f) < 0.1f) ? "PASSED" : "FAILED"); cudaFree(d_out); free(h_out); return 0; }