diff --git a/tests/unit/test_mma_ts.cu b/tests/unit/test_mma_ts.cu index 9c8e120f..13288e6c 100644 --- a/tests/unit/test_mma_ts.cu +++ b/tests/unit/test_mma_ts.cu @@ -52,25 +52,29 @@ test_mma_ts(float* o_out) // Write A = all 1.0 into TMEM columns 0-15 (128 rows × 16 columns) if (wid == 0) { - for (int col = 0; col < 16; col++) { - // Each column: 128 FP32. Lane i writes positions i*4..i*4+3 - float v0 = 1.0f, v1 = 1.0f, v2 = 1.0f, v3 = 1.0f; - tmem_store(tb_a + col, f32_to_u32(v0), f32_to_u32(v1), f32_to_u32(v2), f32_to_u32(v3)); - } + // First write column 0 and verify + float v0 = 1.0f, v1 = 1.0f, v2 = 1.0f, v3 = 1.0f; + tmem_store(tb_a, f32_to_u32(v0), f32_to_u32(v1), f32_to_u32(v2), f32_to_u32(v3)); tmem_fence_store(); } __syncthreads(); - // Read back A to verify it was written correctly + // Read back column 0 if (wid == 0) { - float check = 0.0f; - for (int col = 0; col < 16; col++) { - uint32_t u0, u1, u2, u3; - tmem_load(tb_a + col, u0, u1, u2, u3); - tmem_fence_load(); - check += u32_to_f32(u0); + uint32_t u0, u1, u2, u3; + tmem_load(tb_a, u0, u1, u2, u3); + tmem_fence_load(); + if (lane == 0) printf("A[0,0] = %.1f (expect 1.0)\n", u32_to_f32(u0)); + } + __syncthreads(); + + // Write remaining columns + if (wid == 0) { + for (int col = 1; col < 16; col++) { + float v0 = 1.0f, v1 = 1.0f, v2 = 1.0f, v3 = 1.0f; + tmem_store(tb_a + col, f32_to_u32(v0), f32_to_u32(v1), f32_to_u32(v2), f32_to_u32(v3)); } - if (lane == 0) printf("A sum (lane 0, col 0, pos 0..3): %.1f (expect 16.0)\n", check); + tmem_fence_store(); } __syncthreads();