test: 16x256b.x1 loads with uint32_t regs, matching working pattern

This commit is contained in:
2026-05-28 23:03:10 +00:00
parent fffb493b0e
commit 95003eced2

View File

@@ -1,10 +1,6 @@
/**
* Test: Can we do 16x256b.x1 LOADS multiple times without crashing?
* (The crash was on 16x256b.x1 STORES, not loads.)
*
* If loads work multiple times, we can:
* - Use 16x256b.x1 for softmax reads (128 rows, 1 column per call)
* - Use 32x32b.x8 for everything else (stores, PV, epilogue)
* Uses uint32_t registers (not float) per the working test_tmem_minimal.cu pattern.
*/
#include <cuda_runtime.h>
@@ -40,96 +36,70 @@ test_16x256b_loads(float* results) {
__syncwarp();
uint32_t tb = *sTmemBase;
// Write data via 32x32b.x8 (known working for stores)
// Write data via 32x32b.x8 (known working)
{
float vals[8];
for (int c = 0; c < 8; c++) vals[c] = (float)(lane * 10 + c);
uint32_t ivals[8];
for (int c = 0; c < 8; c++) memcpy(&ivals[c], &vals[c], 4);
// Write to column groups 0-3 (32 columns)
for (int n = 0; n < 4; n++) {
asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};"
:: "r"(tb + n * 8),
"r"(ivals[0]), "r"(ivals[1]), "r"(ivals[2]), "r"(ivals[3]),
"r"(ivals[4]), "r"(ivals[5]), "r"(ivals[6]), "r"(ivals[7]));
asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory");
// Update vals for next column group
for (int c = 0; c < 8; c++) vals[c] = (float)(lane * 10 + n * 8 + c);
for (int c = 0; c < 8; c++) memcpy(&ivals[c], &vals[c], 4);
}
asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};"
:: "r"(tb + 0),
"r"(ivals[0]), "r"(ivals[1]), "r"(ivals[2]), "r"(ivals[3]),
"r"(ivals[4]), "r"(ivals[5]), "r"(ivals[6]), "r"(ivals[7]));
asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory");
}
// Now try reading with 16x256b.x1 loads
// 16x256b.x1: lane l reads 4 FP32 values (rows l*4+0..3) from 1 column
// 16x256b.x1 loads using uint32_t registers (matching test_tmem_minimal.cu)
int load_count = 0;
int pass = 1;
// Read column 0 — lane 0 should get rows 0-3, lane 1 should get rows 4-7, etc.
// Load column 0
{
float v0, v1, v2, v3;
uint32_t u0, u1, u2, u3;
asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4];"
: "=f"(v0), "=f"(v1), "=f"(v2), "=f"(v3)
: "r"(tb + 0)); // column 0
: "=r"(u0), "=r"(u1), "=r"(u2), "=r"(u3) : "r"(tb + 0));
asm volatile("tcgen05.wait::ld.sync.aligned;");
load_count++;
// Expected data: the 32x32b.x8 store wrote lane l's values to "row l" in TMEM,
// so row l, col c = l*10 + c; for col 0: row 0 = 0.0, row 1 = 10.0, row 2 = 20.0, ...
// Lane l of the 16x256b.x1 load should see the values for rows l*4+0..3.
// Lane 0 therefore reads rows 0-3: v0=row0=0, v1=row1=10, v2=row2=20, v3=row3=30
if (lane == 0) {
results[0] = v0;
results[1] = v1;
results[2] = v2;
results[3] = v3;
}
if (lane == 1) {
results[4] = v0;
results[5] = v1;
results[6] = v2;
results[7] = v3;
}
float v0, v1, v2, v3;
memcpy(&v0, &u0, 4); memcpy(&v1, &u1, 4);
memcpy(&v2, &u2, 4); memcpy(&v3, &u3, 4);
if (lane == 0) { results[0]=v0; results[1]=v1; results[2]=v2; results[3]=v3; }
if (lane == 1) { results[4]=v0; results[5]=v1; results[6]=v2; results[7]=v3; }
}
// Read column 1 (2nd 16x256b.x1 load — does it crash?)
// Load column 1 (2nd load — does it crash?)
{
float v0, v1, v2, v3;
uint32_t u0, u1, u2, u3;
asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4];"
: "=f"(v0), "=f"(v1), "=f"(v2), "=f"(v3)
: "r"(tb + 1)); // column 1
: "=r"(u0), "=r"(u1), "=r"(u2), "=r"(u3) : "r"(tb + 1));
asm volatile("tcgen05.wait::ld.sync.aligned;");
load_count++;
if (lane == 0) {
results[8] = v0;
results[9] = v1;
results[10] = v2;
results[11] = v3;
}
float v0, v1, v2, v3;
memcpy(&v0, &u0, 4); memcpy(&v1, &u1, 4);
memcpy(&v2, &u2, 4); memcpy(&v3, &u3, 4);
if (lane == 0) { results[8]=v0; results[9]=v1; results[10]=v2; results[11]=v3; }
}
// Read column 8 (8th column — more 16x256b.x1 loads)
// Load column 8 (3rd load)
{
float v0, v1, v2, v3;
uint32_t u0, u1, u2, u3;
asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4];"
: "=f"(v0), "=f"(v1), "=f"(v2), "=f"(v3)
: "r"(tb + 8));
: "=r"(u0), "=r"(u1), "=r"(u2), "=r"(u3) : "r"(tb + 8));
asm volatile("tcgen05.wait::ld.sync.aligned;");
load_count++;
if (lane == 0) {
results[12] = v0; // row 0, col 8 = 0*10+8 = 8.0
results[13] = v1; // row 1, col 8 = 10+8 = 18.0? No, 1*10+0=10 + 8 = 18
results[14] = v2;
results[15] = v3;
}
float v0, v1, v2, v3;
memcpy(&v0, &u0, 4); memcpy(&v1, &u1, 4);
memcpy(&v2, &u2, 4); memcpy(&v3, &u3, 4);
if (lane == 0) { results[12]=v0; results[13]=v1; results[14]=v2; results[15]=v3; }
}
// Store load count (if we get here, loads didn't crash)
if (lane == 0) results[16] = (float)load_count;
tmem_dealloc(tb, TMEM_N);
}
@@ -145,7 +115,7 @@ int main() {
cudaError_t err = cudaDeviceSynchronize();
if (err != cudaSuccess) {
printf("CUDA ERROR: %s\n", cudaGetErrorString(err));
printf("16x256b.x1 loads CRASHED (likely on 2nd or 3rd call)\n");
printf("16x256b.x1 loads CRASHED\n");
cudaFree(d_r);
return 1;
}
@@ -155,20 +125,14 @@ int main() {
printf("Load count: %d (3 loads completed = no crash)\n\n", (int)h[16]);
printf("Column 0, lane 0 (expect rows 0-3 = 0,10,20,30): %.1f %.1f %.1f %.1f\n",
h[0], h[1], h[2], h[3]);
printf("Column 0, lane 1 (expect rows 4-7 = 40,50,60,70): %.1f %.1f %.1f %.1f\n",
h[4], h[5], h[6], h[7]);
printf("Column 1, lane 0 (expect rows 0-3 = 1,11,21,31): %.1f %.1f %.1f %.1f\n",
h[8], h[9], h[10], h[11]);
printf("Column 8, lane 0 (expect row 0 col 8 = 8, row 1 col 8 = 18, etc): %.1f %.1f %.1f %.1f\n",
h[12], h[13], h[14], h[15]);
printf("Col 0, lane 0 (expect rows 0-3: 0,10,20,30): %.1f %.1f %.1f %.1f\n", h[0], h[1], h[2], h[3]);
printf("Col 0, lane 1 (expect rows 4-7: 40,50,60,70): %.1f %.1f %.1f %.1f\n", h[4], h[5], h[6], h[7]);
printf("Col 1, lane 0 (expect rows 0-3 at col 1: 1,11,21,31): %.1f %.1f %.1f %.1f\n", h[8], h[9], h[10], h[11]);
printf("Col 8, lane 0 (expect row 0 col 8 = 8, row 1 col 8 = 18, etc): %.1f %.1f %.1f %.1f\n", h[12], h[13], h[14], h[15]);
// Verify: column 0, lane 0 should give [0, 10, 20, 30]
int pass = (fabsf(h[0] - 0.0f) < 0.01f) && (fabsf(h[1] - 10.0f) < 0.01f) &&
(fabsf(h[2] - 20.0f) < 0.01f) && (fabsf(h[3] - 30.0f) < 0.01f);
printf("\nResult: %s\n", pass ? "16x256b.x1 LOADS work multiple times!" : "Data mismatch");
int pass = (fabsf(h[0]-0.0f)<0.01f) && (fabsf(h[1]-10.0f)<0.01f) &&
(fabsf(h[2]-20.0f)<0.01f) && (fabsf(h[3]-30.0f)<0.01f) && (int)h[16]==3;
printf("\nResult: %s\n", pass ? "16x256b.x1 LOADS work multiple times!" : "Data mismatch or crash");
cudaFree(d_r);
return pass ? 0 : 1;
}