test: step-by-step TMEM write/read debug for TS MMA

This commit is contained in:
2026-05-28 13:33:36 +00:00
parent c05cc1ac93
commit a7c81d66ba

View File

@@ -52,25 +52,29 @@ test_mma_ts(float* o_out)
// Write A = all 1.0 into TMEM columns 0-15 (128 rows × 16 columns)
if (wid == 0) {
for (int col = 0; col < 16; col++) {
// Each column: 128 FP32. Lane i writes positions i*4..i*4+3
float v0 = 1.0f, v1 = 1.0f, v2 = 1.0f, v3 = 1.0f;
tmem_store(tb_a + col, f32_to_u32(v0), f32_to_u32(v1), f32_to_u32(v2), f32_to_u32(v3));
}
// First write column 0 and verify
float v0 = 1.0f, v1 = 1.0f, v2 = 1.0f, v3 = 1.0f;
tmem_store(tb_a, f32_to_u32(v0), f32_to_u32(v1), f32_to_u32(v2), f32_to_u32(v3));
tmem_fence_store();
}
__syncthreads();
// Read back A to verify it was written correctly
// Read back column 0
if (wid == 0) {
float check = 0.0f;
for (int col = 0; col < 16; col++) {
uint32_t u0, u1, u2, u3;
tmem_load(tb_a + col, u0, u1, u2, u3);
tmem_fence_load();
check += u32_to_f32(u0);
uint32_t u0, u1, u2, u3;
tmem_load(tb_a, u0, u1, u2, u3);
tmem_fence_load();
if (lane == 0) printf("A[0,0] = %.1f (expect 1.0)\n", u32_to_f32(u0));
}
__syncthreads();
// Write remaining columns
if (wid == 0) {
for (int col = 1; col < 16; col++) {
float v0 = 1.0f, v1 = 1.0f, v2 = 1.0f, v3 = 1.0f;
tmem_store(tb_a + col, f32_to_u32(v0), f32_to_u32(v1), f32_to_u32(v2), f32_to_u32(v3));
}
if (lane == 0) printf("A sum (lane 0, col 0, pos 0..3): %.1f (expect 16.0)\n", check);
tmem_fence_store();
}
__syncthreads();