test: TMEM cross-warp visibility with different sync strategies
This commit is contained in:
@@ -1,9 +1,7 @@
|
||||
/**
|
||||
* Minimal TMEM read test: write data via 32x32b.x8 store, read it back.
|
||||
* Then test if MMA output can be read by different warps.
|
||||
*
|
||||
* Step 1: Verify 32x32b.x8 store+load round-trip with 1 warp.
|
||||
* Step 2: Write known data via MMA, read with 4 warps.
|
||||
* Test: TMEM cross-warp visibility after 32x32b.x8 store.
|
||||
* Try different synchronization strategies to make TMEM stores
|
||||
* visible to all warps.
|
||||
*/
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
@@ -26,57 +24,9 @@ __device__ void tmem_dealloc(uint32_t tmem_ptr, int num_cols) {
|
||||
:: "r"(tmem_ptr), "r"(num_cols));
|
||||
}
|
||||
|
||||
/**
 * Test 1: TMEM store + load round-trip using a single warp.
 *
 * Launch layout: one block of 32 threads (a single warp); the kernel indexes
 * purely by threadIdx.x. Requires at least sizeof(uint32_t) bytes of dynamic
 * shared memory: the first word is handed to tmem_alloc() and then read back
 * as the TMEM base address.
 *
 * Each lane stores the eight values (lane*10 + c), c = 0..7, with
 * tcgen05.st 32x32b.x8 at the allocation base, reads them back with the
 * matching tcgen05.ld, and lanes 0-3 copy what they read to
 * results[lane*8 + c].
 *
 * @param results  Device buffer with room for at least 32 floats.
 *
 * NOTE(review): tcgen05 instructions presumably require a Blackwell-class
 * target (sm_100a) — confirm against the build flags.
 */
__global__ void __launch_bounds__(32)
test_tmem_store_load(float* results) {
    constexpr int kTmemCols = 128;  // column count passed to tmem_alloc / tmem_dealloc
    constexpr int kVals = 8;        // values per lane moved by the .x8 store/load

    const int lane = threadIdx.x;

    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = reinterpret_cast<uint32_t*>(sbuf);

    // tmem_alloc takes the shared-memory address of the word it should fill in.
    const uint32_t smemAddr = static_cast<uint32_t>(__cvta_generic_to_shared(sTmemBase));
    tmem_alloc(smemAddr, kTmemCols);
    __syncwarp();
    const uint32_t tmemBase = *sTmemBase;

    // Store: lane l writes value (l*10 + c) to columns 0-7.
    {
        uint32_t bits[kVals];
        #pragma unroll
        for (int c = 0; c < kVals; ++c) {
            // The store takes .b32 operands, so pass the raw float bit pattern.
            bits[c] = __float_as_uint(static_cast<float>(lane * 10 + c));
        }

        asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};"
                     :: "r"(tmemBase + 0),
                        "r"(bits[0]), "r"(bits[1]), "r"(bits[2]), "r"(bits[3]),
                        "r"(bits[4]), "r"(bits[5]), "r"(bits[6]), "r"(bits[7]));
        // Wait for the TMEM store to complete before reading it back.
        asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory");
    }

    // Read back.
    {
        float readback[kVals];
        asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
                     : "=f"(readback[0]), "=f"(readback[1]), "=f"(readback[2]), "=f"(readback[3]),
                       "=f"(readback[4]), "=f"(readback[5]), "=f"(readback[6]), "=f"(readback[7])
                     : "r"(tmemBase + 0));
        // "memory" clobber (matching the wait::st above) keeps the compiler from
        // scheduling the global stores below ahead of the load-completion wait.
        asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");

        // Store first 4 lanes' values.
        if (lane < 4) {
            #pragma unroll
            for (int c = 0; c < kVals; ++c) {
                results[lane * kVals + c] = readback[c];
            }
        }
    }

    tmem_dealloc(tmemBase, kTmemCols);
}
|
||||
|
||||
// Test 2: 4 warps read TMEM after 32x32b.x8 store from warp 0
|
||||
// Strategy: warp 0 stores, then ALL warps fence + sync, then read
|
||||
__global__ void __launch_bounds__(128)
|
||||
test_tmem_4warp_after_store(float* results) {
|
||||
test_tmem_sync(float* results, int strategy) {
|
||||
const int tid = threadIdx.x;
|
||||
const int wid = tid / 32;
|
||||
const int lane = tid % 32;
|
||||
@@ -92,7 +42,7 @@ test_tmem_4warp_after_store(float* results) {
|
||||
__syncthreads();
|
||||
uint32_t tb = *sTmemBase;
|
||||
|
||||
// Warp 0 stores: lane l writes (lane*10 + c) to columns 0-7
|
||||
// Warp 0 stores data
|
||||
if (wid == 0) {
|
||||
float vals[8];
|
||||
for (int c = 0; c < 8; c++) vals[c] = (float)(lane * 10 + c);
|
||||
@@ -105,9 +55,23 @@ test_tmem_4warp_after_store(float* results) {
|
||||
"r"(ivals[4]), "r"(ivals[5]), "r"(ivals[6]), "r"(ivals[7]));
|
||||
asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory");
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// All 4 warps read columns 0-7
|
||||
// Strategy 0: just __syncthreads
|
||||
// Strategy 1: __syncthreads + tcgen05.wait::st on all warps + __syncthreads
|
||||
// Strategy 2: __syncthreads + fence + __syncthreads
|
||||
if (strategy == 0) {
|
||||
__syncthreads();
|
||||
} else if (strategy == 1) {
|
||||
__syncthreads();
|
||||
asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory");
|
||||
__syncthreads();
|
||||
} else if (strategy == 2) {
|
||||
__syncthreads();
|
||||
asm volatile("fence.sc.gpu;" ::: "memory");
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// All 4 warps read
|
||||
if (wid < 4) {
|
||||
float tmp[8];
|
||||
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
|
||||
@@ -116,18 +80,10 @@ test_tmem_4warp_after_store(float* results) {
|
||||
: "r"(tb + 0));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
|
||||
// Lane 0 from each warp stores
|
||||
if (lane == 0) {
|
||||
for (int c = 0; c < 8; c++) {
|
||||
results[wid * 8 + c] = tmp[c];
|
||||
}
|
||||
}
|
||||
// Lane 5 from each warp
|
||||
if (lane == 5) {
|
||||
for (int c = 0; c < 8; c++) {
|
||||
results[32 + wid * 8 + c] = tmp[c];
|
||||
}
|
||||
}
|
||||
// Lane 0: store col 0 value
|
||||
if (lane == 0) results[wid] = tmp[0];
|
||||
// Lane 5: store col 0 value (should be 50.0)
|
||||
if (lane == 5) results[4 + wid] = tmp[0];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
@@ -135,62 +91,42 @@ test_tmem_4warp_after_store(float* results) {
|
||||
}
|
||||
|
||||
int main() {
|
||||
printf("Step 1: 1-warp store+load round-trip\n");
|
||||
printf("=====================================\n");
|
||||
printf("TMEM cross-warp visibility test\n");
|
||||
printf("================================\n\n");
|
||||
|
||||
float* d_r1;
|
||||
cudaMalloc(&d_r1, 64 * sizeof(float));
|
||||
cudaMemset(d_r1, 0, 64 * sizeof(float));
|
||||
test_tmem_store_load<<<1, 32, 256>>>(d_r1);
|
||||
for (int strat = 0; strat < 3; strat++) {
|
||||
const char* names[] = {"__syncthreads only",
|
||||
"__syncthreads + tcgen05.wait::st + __syncthreads",
|
||||
"__syncthreads + fence.sc.gpu + __syncthreads"};
|
||||
printf("Strategy %d: %s\n", strat, names[strat]);
|
||||
|
||||
cudaError_t err = cudaDeviceSynchronize();
|
||||
if (err != cudaSuccess) {
|
||||
printf("CUDA ERROR: %s\n", cudaGetErrorString(err));
|
||||
printf("Step 1 FAILED\n\n");
|
||||
} else {
|
||||
float h1[64];
|
||||
cudaMemcpy(h1, d_r1, 64 * sizeof(float), cudaMemcpyDeviceToHost);
|
||||
printf("Lane 0 (expect 0..7): "); for(int c=0;c<8;c++) printf("%.0f ",h1[c]); printf("\n");
|
||||
printf("Lane 1 (expect 10..17): "); for(int c=0;c<8;c++) printf("%.0f ",h1[8+c]); printf("\n");
|
||||
printf("Lane 2 (expect 20..27): "); for(int c=0;c<8;c++) printf("%.0f ",h1[16+c]); printf("\n");
|
||||
printf("Lane 3 (expect 30..37): "); for(int c=0;c<8;c++) printf("%.0f ",h1[24+c]); printf("\n");
|
||||
float* d_r;
|
||||
cudaMalloc(&d_r, 32 * sizeof(float));
|
||||
cudaMemset(d_r, 0, 32 * sizeof(float));
|
||||
test_tmem_sync<<<1, 128, 256>>>(d_r, strat);
|
||||
|
||||
int ok = 1;
|
||||
for(int l=0;l<4;l++) for(int c=0;c<8;c++) if(fabsf(h1[l*8+c]-(l*10+c))>0.01f) ok=0;
|
||||
printf("Step 1: %s\n\n", ok ? "PASSED" : "FAILED");
|
||||
cudaError_t err = cudaDeviceSynchronize();
|
||||
if (err != cudaSuccess) {
|
||||
printf(" CUDA ERROR: %s\n", cudaGetErrorString(err));
|
||||
} else {
|
||||
float h[32];
|
||||
cudaMemcpy(h, d_r, 32 * sizeof(float), cudaMemcpyDeviceToHost);
|
||||
printf(" Lane 0, col 0 (expect 0.0 for all): ");
|
||||
for(int w=0;w<4;w++) printf("w%d=%.0f ", w, h[w]);
|
||||
printf("\n");
|
||||
printf(" Lane 5, col 0 (expect 50.0 for all): ");
|
||||
for(int w=0;w<4;w++) printf("w%d=%.0f ", w, h[4+w]);
|
||||
printf("\n");
|
||||
|
||||
int ok = 1;
|
||||
for(int w=0;w<4;w++) {
|
||||
if (fabsf(h[w] - 0.0f) > 0.01f) ok = 0;
|
||||
if (fabsf(h[4+w] - 50.0f) > 0.01f) ok = 0;
|
||||
}
|
||||
printf(" Result: %s\n\n", ok ? "ALL WARPS SEE DATA" : "SOME WARPS SEE ZEROS");
|
||||
}
|
||||
cudaFree(d_r);
|
||||
}
|
||||
|
||||
printf("Step 2: 4-warps read after warp-0 store\n");
|
||||
printf("========================================\n");
|
||||
printf("(Do all warps see the same data, or different?)\n");
|
||||
|
||||
float* d_r2;
|
||||
cudaMalloc(&d_r2, 128 * sizeof(float));
|
||||
cudaMemset(d_r2, 0, 128 * sizeof(float));
|
||||
test_tmem_4warp_after_store<<<1, 128, 256>>>(d_r2);
|
||||
|
||||
err = cudaDeviceSynchronize();
|
||||
if (err != cudaSuccess) {
|
||||
printf("CUDA ERROR: %s\n", cudaGetErrorString(err));
|
||||
printf("Step 2 FAILED\n\n");
|
||||
} else {
|
||||
float h2[128];
|
||||
cudaMemcpy(h2, d_r2, 128 * sizeof(float), cudaMemcpyDeviceToHost);
|
||||
printf("Lane 0 from each warp (expect all same: 0..7 if all see row 0):\n");
|
||||
for(int w=0;w<4;w++) {
|
||||
printf(" Warp %d: ", w);
|
||||
for(int c=0;c<8;c++) printf("%.0f ", h2[w*8+c]);
|
||||
printf("\n");
|
||||
}
|
||||
printf("Lane 5 from each warp (expect all same: 50..57 if all see row 5):\n");
|
||||
for(int w=0;w<4;w++) {
|
||||
printf(" Warp %d: ", w);
|
||||
for(int c=0;c<8;c++) printf("%.0f ", h2[32+w*8+c]);
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
cudaFree(d_r1);
|
||||
cudaFree(d_r2);
|
||||
return 0;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user