test: TMEM cross-warp visibility with different sync strategies

This commit is contained in:
2026-05-28 22:59:31 +00:00
parent 77d190278e
commit 6b0d57074a

View File

@@ -1,9 +1,7 @@
/**
* Minimal TMEM read test: write data via 32x32b.x8 store, read it back.
* Then test if MMA output can be read by different warps.
*
* Step 1: Verify 32x32b.x8 store+load round-trip with 1 warp.
* Step 2: Write known data via MMA, read with 4 warps.
* Test: TMEM cross-warp visibility after 32x32b.x8 store.
* Try different synchronization strategies to make TMEM stores
* visible to all warps.
*/
#include <cuda_runtime.h>
@@ -26,57 +24,9 @@ __device__ void tmem_dealloc(uint32_t tmem_ptr, int num_cols) {
:: "r"(tmem_ptr), "r"(num_cols));
}
// Test 1: Direct store + load round-trip with 1 warp.
//
// A single warp writes a known per-lane pattern into TMEM columns 0-7 with
// tcgen05.st (32x32b.x8), reads the same location back with tcgen05.ld, and
// copies what lanes 0-3 observed into `results`.
//
// Launch layout: compiled with __launch_bounds__(32); `lane` is taken straight
// from threadIdx.x, so the kernel is written for one block of 32 threads.
// Shared memory: the first 4 bytes of the dynamic shared buffer are read back
// as the TMEM base address after tmem_alloc().
// results: written at indices lane*8 + c for lane in [0,4), c in [0,8), i.e.
// the caller must provide at least 32 floats.
__global__ void __launch_bounds__(32)
test_tmem_store_load(float* results) {
const int lane = threadIdx.x;
// Column count handed to both tmem_alloc() and tmem_dealloc().
const int TMEM_N = 128;
extern __shared__ char sbuf[];
uint32_t* sTmemBase = (uint32_t*)sbuf;
// Shared-state-space address of sTmemBase, passed to the allocator.
uint32_t sp = __cvta_generic_to_shared(sTmemBase);
// NOTE(review): tmem_alloc() is defined outside this view; presumably it
// writes the allocated TMEM base address to the shared location `sp` — confirm.
tmem_alloc(sp, TMEM_N);
__syncwarp();
// Every lane picks up the base address from shared memory.
uint32_t tb = *sTmemBase;
// Store: lane l writes value (lane*10 + c) to columns 0-7
{
float vals[8];
for (int c = 0; c < 8; c++) vals[c] = (float)(lane * 10 + c);
// The store asm takes 32-bit integer ("r") operands, so bit-copy each float
// into a uint32_t instead of converting its value.
uint32_t ivals[8];
for (int c = 0; c < 8; c++) memcpy(&ivals[c], &vals[c], 4);
asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};"
:: "r"(tb + 0),
"r"(ivals[0]), "r"(ivals[1]), "r"(ivals[2]), "r"(ivals[3]),
"r"(ivals[4]), "r"(ivals[5]), "r"(ivals[6]), "r"(ivals[7]));
// Wait for the TMEM store issued above before reading it back.
asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory");
}
// Read back the same 8 columns from the same TMEM address.
{
float tmp[8];
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
: "r"(tb + 0));
// NOTE(review): unlike the wait::st above, this wait carries no "memory"
// clobber — confirm the compiler cannot move the results[] stores below
// ahead of it.
asm volatile("tcgen05.wait::ld.sync.aligned;");
// Only lanes 0-3 report; lane l fills results[l*8 .. l*8+7].
if (lane < 4) {
for (int c = 0; c < 8; c++) {
results[lane * 8 + c] = tmp[c];
}
}
}
// Release the TMEM columns obtained at the top of the kernel.
tmem_dealloc(tb, TMEM_N);
}
// Test 2: 4 warps read TMEM after 32x32b.x8 store from warp 0
// Strategy: warp 0 stores, then ALL warps fence + sync, then read
__global__ void __launch_bounds__(128)
test_tmem_4warp_after_store(float* results) {
test_tmem_sync(float* results, int strategy) {
const int tid = threadIdx.x;
const int wid = tid / 32;
const int lane = tid % 32;
@@ -92,7 +42,7 @@ test_tmem_4warp_after_store(float* results) {
__syncthreads();
uint32_t tb = *sTmemBase;
// Warp 0 stores: lane l writes (lane*10 + c) to columns 0-7
// Warp 0 stores data
if (wid == 0) {
float vals[8];
for (int c = 0; c < 8; c++) vals[c] = (float)(lane * 10 + c);
@@ -105,9 +55,23 @@ test_tmem_4warp_after_store(float* results) {
"r"(ivals[4]), "r"(ivals[5]), "r"(ivals[6]), "r"(ivals[7]));
asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory");
}
__syncthreads();
// All 4 warps read columns 0-7
// Strategy 0: __syncthreads only
// Strategy 1: __syncthreads + tcgen05.wait::st on all warps + __syncthreads
// Strategy 2: __syncthreads + fence.sc.gpu + __syncthreads
if (strategy == 0) {
__syncthreads();
} else if (strategy == 1) {
__syncthreads();
asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory");
__syncthreads();
} else if (strategy == 2) {
__syncthreads();
asm volatile("fence.sc.gpu;" ::: "memory");
__syncthreads();
}
// All 4 warps read
if (wid < 4) {
float tmp[8];
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
@@ -116,18 +80,10 @@ test_tmem_4warp_after_store(float* results) {
: "r"(tb + 0));
asm volatile("tcgen05.wait::ld.sync.aligned;");
// Lane 0 from each warp stores
if (lane == 0) {
for (int c = 0; c < 8; c++) {
results[wid * 8 + c] = tmp[c];
}
}
// Lane 5 from each warp
if (lane == 5) {
for (int c = 0; c < 8; c++) {
results[32 + wid * 8 + c] = tmp[c];
}
}
// Lane 0: store col 0 value (should be 0.0, which is also the buffer's memset value)
if (lane == 0) results[wid] = tmp[0];
// Lane 5: store col 0 value (should be 50.0)
if (lane == 5) results[4 + wid] = tmp[0];
}
__syncthreads();
@@ -135,62 +91,42 @@ test_tmem_4warp_after_store(float* results) {
}
int main() {
printf("Step 1: 1-warp store+load round-trip\n");
printf("=====================================\n");
printf("TMEM cross-warp visibility test\n");
printf("================================\n\n");
float* d_r1;
cudaMalloc(&d_r1, 64 * sizeof(float));
cudaMemset(d_r1, 0, 64 * sizeof(float));
test_tmem_store_load<<<1, 32, 256>>>(d_r1);
for (int strat = 0; strat < 3; strat++) {
const char* names[] = {"__syncthreads only",
"__syncthreads + tcgen05.wait::st + __syncthreads",
"__syncthreads + fence.sc.gpu + __syncthreads"};
printf("Strategy %d: %s\n", strat, names[strat]);
cudaError_t err = cudaDeviceSynchronize();
if (err != cudaSuccess) {
printf("CUDA ERROR: %s\n", cudaGetErrorString(err));
printf("Step 1 FAILED\n\n");
} else {
float h1[64];
cudaMemcpy(h1, d_r1, 64 * sizeof(float), cudaMemcpyDeviceToHost);
printf("Lane 0 (expect 0..7): "); for(int c=0;c<8;c++) printf("%.0f ",h1[c]); printf("\n");
printf("Lane 1 (expect 10..17): "); for(int c=0;c<8;c++) printf("%.0f ",h1[8+c]); printf("\n");
printf("Lane 2 (expect 20..27): "); for(int c=0;c<8;c++) printf("%.0f ",h1[16+c]); printf("\n");
printf("Lane 3 (expect 30..37): "); for(int c=0;c<8;c++) printf("%.0f ",h1[24+c]); printf("\n");
float* d_r;
cudaMalloc(&d_r, 32 * sizeof(float));
cudaMemset(d_r, 0, 32 * sizeof(float));
test_tmem_sync<<<1, 128, 256>>>(d_r, strat);
int ok = 1;
for(int l=0;l<4;l++) for(int c=0;c<8;c++) if(fabsf(h1[l*8+c]-(l*10+c))>0.01f) ok=0;
printf("Step 1: %s\n\n", ok ? "PASSED" : "FAILED");
cudaError_t err = cudaDeviceSynchronize();
if (err != cudaSuccess) {
printf(" CUDA ERROR: %s\n", cudaGetErrorString(err));
} else {
float h[32];
cudaMemcpy(h, d_r, 32 * sizeof(float), cudaMemcpyDeviceToHost);
printf(" Lane 0, col 0 (expect 0.0 for all): ");
for(int w=0;w<4;w++) printf("w%d=%.0f ", w, h[w]);
printf("\n");
printf(" Lane 5, col 0 (expect 50.0 for all): ");
for(int w=0;w<4;w++) printf("w%d=%.0f ", w, h[4+w]);
printf("\n");
int ok = 1;
for(int w=0;w<4;w++) {
if (fabsf(h[w] - 0.0f) > 0.01f) ok = 0;
if (fabsf(h[4+w] - 50.0f) > 0.01f) ok = 0;
}
printf(" Result: %s\n\n", ok ? "ALL WARPS SEE DATA" : "SOME WARPS SEE ZEROS");
}
cudaFree(d_r);
}
printf("Step 2: 4-warps read after warp-0 store\n");
printf("========================================\n");
printf("(Do all warps see the same data, or different?)\n");
float* d_r2;
cudaMalloc(&d_r2, 128 * sizeof(float));
cudaMemset(d_r2, 0, 128 * sizeof(float));
test_tmem_4warp_after_store<<<1, 128, 256>>>(d_r2);
err = cudaDeviceSynchronize();
if (err != cudaSuccess) {
printf("CUDA ERROR: %s\n", cudaGetErrorString(err));
printf("Step 2 FAILED\n\n");
} else {
float h2[128];
cudaMemcpy(h2, d_r2, 128 * sizeof(float), cudaMemcpyDeviceToHost);
printf("Lane 0 from each warp (expect all same: 0..7 if all see row 0):\n");
for(int w=0;w<4;w++) {
printf(" Warp %d: ", w);
for(int c=0;c<8;c++) printf("%.0f ", h2[w*8+c]);
printf("\n");
}
printf("Lane 5 from each warp (expect all same: 50..57 if all see row 5):\n");
for(int w=0;w<4;w++) {
printf(" Warp %d: ", w);
for(int c=0;c<8;c++) printf("%.0f ", h2[32+w*8+c]);
printf("\n");
}
}
cudaFree(d_r1);
cudaFree(d_r2);
return 0;
}