diff --git a/tests/unit/test_tmem_cols.cu b/tests/unit/test_tmem_cols.cu index bf957b4b..873f36b8 100644 --- a/tests/unit/test_tmem_cols.cu +++ b/tests/unit/test_tmem_cols.cu @@ -49,16 +49,14 @@ __global__ void test_tmem_loop(float* out) { asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory"); __syncthreads(); - // Read back — skip for now - // if (threadIdx.x < 32) { - // for (int c = 0; c < 4; c++) { - // uint32_t u0, u1, u2, u3; - // tmem_load(tb + c, u0, u1, u2, u3); - // float v0; memcpy(&v0, &u0, 4); - // if (lane == 0) out[c] = v0; - // } - // } - // __syncthreads(); + // Read back — 1 column only + if (threadIdx.x < 32) { + uint32_t u0, u1, u2, u3; + tmem_load(tb + 0, u0, u1, u2, u3); + float v0; memcpy(&v0, &u0, 4); + if (lane == 0) out[0] = v0; + } + __syncthreads(); if (threadIdx.x < 32) tmem_dealloc(tb, 32); }