diff --git a/tests/unit/test_tmem_4warp_read.cu b/tests/unit/test_tmem_4warp_read.cu index d2c9fa16..ec70d33d 100644 --- a/tests/unit/test_tmem_4warp_read.cu +++ b/tests/unit/test_tmem_4warp_read.cu @@ -1,9 +1,7 @@ /** - * Minimal TMEM read test: write data via 32x32b.x8 store, read it back. - * Then test if MMA output can be read by different warps. - * - * Step 1: Verify 32x32b.x8 store+load round-trip with 1 warp. - * Step 2: Write known data via MMA, read with 4 warps. + * Test: TMEM cross-warp visibility after 32x32b.x8 store. + * Try different synchronization strategies to make TMEM stores + * visible to all warps. */ #include @@ -26,57 +24,9 @@ __device__ void tmem_dealloc(uint32_t tmem_ptr, int num_cols) { :: "r"(tmem_ptr), "r"(num_cols)); } -// Test 1: Direct store + load with 1 warp -__global__ void __launch_bounds__(32) -test_tmem_store_load(float* results) { - const int lane = threadIdx.x; - const int TMEM_N = 128; - - extern __shared__ char sbuf[]; - uint32_t* sTmemBase = (uint32_t*)sbuf; - - uint32_t sp = __cvta_generic_to_shared(sTmemBase); - tmem_alloc(sp, TMEM_N); - __syncwarp(); - uint32_t tb = *sTmemBase; - - // Store: lane l writes value (lane*10 + c) to columns 0-7 - { - float vals[8]; - for (int c = 0; c < 8; c++) vals[c] = (float)(lane * 10 + c); - uint32_t ivals[8]; - for (int c = 0; c < 8; c++) memcpy(&ivals[c], &vals[c], 4); - - asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};" - :: "r"(tb + 0), - "r"(ivals[0]), "r"(ivals[1]), "r"(ivals[2]), "r"(ivals[3]), - "r"(ivals[4]), "r"(ivals[5]), "r"(ivals[6]), "r"(ivals[7])); - asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory"); - } - - // Read back - { - float tmp[8]; - asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" - : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), - "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) - : "r"(tb + 0)); - asm volatile("tcgen05.wait::ld.sync.aligned;"); - - // Store first 4 lanes' values - if (lane < 4) { - for (int c = 0; c < 8; c++) { - results[lane * 8 + c] = tmp[c]; - } - } - } - - tmem_dealloc(tb, TMEM_N); -} - -// Test 2: 4 warps read TMEM after 32x32b.x8 store from warp 0 +// Strategy: warp 0 stores, then ALL warps fence + sync, then read __global__ void __launch_bounds__(128) -test_tmem_4warp_after_store(float* results) { +test_tmem_sync(float* results, int strategy) { const int tid = threadIdx.x; const int wid = tid / 32; const int lane = tid % 32; @@ -92,7 +42,7 @@ test_tmem_4warp_after_store(float* results) { __syncthreads(); uint32_t tb = *sTmemBase; - // Warp 0 stores: lane l writes (lane*10 + c) to columns 0-7 + // Warp 0 stores data if (wid == 0) { float vals[8]; for (int c = 0; c < 8; c++) vals[c] = (float)(lane * 10 + c); @@ -105,9 +55,23 @@ test_tmem_4warp_after_store(float* results) { "r"(ivals[4]), "r"(ivals[5]), "r"(ivals[6]), "r"(ivals[7])); asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory"); } - __syncthreads(); - // All 4 warps read columns 0-7 + // Strategy 0: just __syncthreads + // Strategy 1: __syncthreads + tcgen05.wait::st on all warps + // Strategy 2: __syncthreads + fence + __syncthreads + if (strategy == 0) { + __syncthreads(); + } else if (strategy == 1) { + __syncthreads(); + asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory"); + __syncthreads(); + } else if (strategy == 2) { + __syncthreads(); + asm volatile("fence.sc.gpu;" ::: "memory"); + __syncthreads(); + } + + // All 4 warps read if (wid < 4) { float tmp[8]; asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" @@ -116,18 +80,10 @@ test_tmem_4warp_after_store(float* results) { : "r"(tb + 0)); asm volatile("tcgen05.wait::ld.sync.aligned;"); - // Lane 0 from each warp stores - if (lane == 0) { - for (int c = 0; c < 8; c++) { - results[wid * 8 + c] = tmp[c]; - } - } - // Lane 5 from each warp - if (lane == 5) { - for (int c = 0; c < 8; c++) { - results[32 + wid * 8 + c] = tmp[c]; - } - } + // Lane 0: store col 0 value + if (lane == 0) results[wid] = tmp[0]; + // Lane 5: store col 0 value (should be 50.0) + if (lane == 5) results[4 + wid] = tmp[0]; } __syncthreads(); @@ -135,62 +91,42 @@ test_tmem_4warp_after_store(float* results) { } int main() { - printf("Step 1: 1-warp store+load round-trip\n"); - printf("=====================================\n"); + printf("TMEM cross-warp visibility test\n"); + printf("================================\n\n"); - float* d_r1; - cudaMalloc(&d_r1, 64 * sizeof(float)); - cudaMemset(d_r1, 0, 64 * sizeof(float)); - test_tmem_store_load<<<1, 32, 256>>>(d_r1); + for (int strat = 0; strat < 3; strat++) { + const char* names[] = {"__syncthreads only", + "__syncthreads + tcgen05.wait::st + __syncthreads", + "__syncthreads + fence.sc.gpu + __syncthreads"}; + printf("Strategy %d: %s\n", strat, names[strat]); - cudaError_t err = cudaDeviceSynchronize(); - if (err != cudaSuccess) { - printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); - printf("Step 1 FAILED\n\n"); - } else { - float h1[64]; - cudaMemcpy(h1, d_r1, 64 * sizeof(float), cudaMemcpyDeviceToHost); - printf("Lane 0 (expect 0..7): "); for(int c=0;c<8;c++) printf("%.0f ",h1[c]); printf("\n"); - printf("Lane 1 (expect 10..17): "); for(int c=0;c<8;c++) printf("%.0f ",h1[8+c]); printf("\n"); - printf("Lane 2 (expect 20..27): "); for(int c=0;c<8;c++) printf("%.0f ",h1[16+c]); printf("\n"); - printf("Lane 3 (expect 30..37): "); for(int c=0;c<8;c++) printf("%.0f ",h1[24+c]); printf("\n"); + float* d_r; + cudaMalloc(&d_r, 32 * sizeof(float)); + cudaMemset(d_r, 0, 32 * sizeof(float)); + test_tmem_sync<<<1, 128, 256>>>(d_r, strat); - int ok = 1; - for(int l=0;l<4;l++) for(int c=0;c<8;c++) if(fabsf(h1[l*8+c]-(l*10+c))>0.01f) ok=0; - printf("Step 1: %s\n\n", ok ? "PASSED" : "FAILED"); + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + printf(" CUDA ERROR: %s\n", cudaGetErrorString(err)); + } else { + float h[32]; + cudaMemcpy(h, d_r, 32 * sizeof(float), cudaMemcpyDeviceToHost); + printf(" Lane 0, col 0 (expect 0.0 for all): "); + for(int w=0;w<4;w++) printf("w%d=%.0f ", w, h[w]); + printf("\n"); + printf(" Lane 5, col 0 (expect 50.0 for all): "); + for(int w=0;w<4;w++) printf("w%d=%.0f ", w, h[4+w]); + printf("\n"); + + int ok = 1; + for(int w=0;w<4;w++) { + if (fabsf(h[w] - 0.0f) > 0.01f) ok = 0; + if (fabsf(h[4+w] - 50.0f) > 0.01f) ok = 0; + } + printf(" Result: %s\n\n", ok ? "ALL WARPS SEE DATA" : "SOME WARPS SEE ZEROS"); + } + cudaFree(d_r); } - printf("Step 2: 4-warps read after warp-0 store\n"); - printf("========================================\n"); - printf("(Do all warps see the same data, or different?)\n"); - - float* d_r2; - cudaMalloc(&d_r2, 128 * sizeof(float)); - cudaMemset(d_r2, 0, 128 * sizeof(float)); - test_tmem_4warp_after_store<<<1, 128, 256>>>(d_r2); - - err = cudaDeviceSynchronize(); - if (err != cudaSuccess) { - printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); - printf("Step 2 FAILED\n\n"); - } else { - float h2[128]; - cudaMemcpy(h2, d_r2, 128 * sizeof(float), cudaMemcpyDeviceToHost); - printf("Lane 0 from each warp (expect all same: 0..7 if all see row 0):\n"); - for(int w=0;w<4;w++) { - printf(" Warp %d: ", w); - for(int c=0;c<8;c++) printf("%.0f ", h2[w*8+c]); - printf("\n"); - } - printf("Lane 5 from each warp (expect all same: 50..57 if all see row 5):\n"); - for(int w=0;w<4;w++) { - printf(" Warp %d: ", w); - for(int c=0;c<8;c++) printf("%.0f ", h2[32+w*8+c]); - printf("\n"); - } - } - - cudaFree(d_r1); - cudaFree(d_r2); return 0; }