test: step-by-step TMEM write/read debug for TS MMA
This commit is contained in:
@@ -52,25 +52,29 @@ test_mma_ts(float* o_out)
|
||||
|
||||
// Write A = all 1.0 into TMEM columns 0-15 (128 rows × 16 columns)
|
||||
if (wid == 0) {
|
||||
for (int col = 0; col < 16; col++) {
|
||||
// Each column: 128 FP32. Lane i writes positions i*4..i*4+3
|
||||
float v0 = 1.0f, v1 = 1.0f, v2 = 1.0f, v3 = 1.0f;
|
||||
tmem_store(tb_a + col, f32_to_u32(v0), f32_to_u32(v1), f32_to_u32(v2), f32_to_u32(v3));
|
||||
}
|
||||
// First write column 0 and verify
|
||||
float v0 = 1.0f, v1 = 1.0f, v2 = 1.0f, v3 = 1.0f;
|
||||
tmem_store(tb_a, f32_to_u32(v0), f32_to_u32(v1), f32_to_u32(v2), f32_to_u32(v3));
|
||||
tmem_fence_store();
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Read back A to verify it was written correctly
|
||||
// Read back column 0
|
||||
if (wid == 0) {
|
||||
float check = 0.0f;
|
||||
for (int col = 0; col < 16; col++) {
|
||||
uint32_t u0, u1, u2, u3;
|
||||
tmem_load(tb_a + col, u0, u1, u2, u3);
|
||||
tmem_fence_load();
|
||||
check += u32_to_f32(u0);
|
||||
uint32_t u0, u1, u2, u3;
|
||||
tmem_load(tb_a, u0, u1, u2, u3);
|
||||
tmem_fence_load();
|
||||
if (lane == 0) printf("A[0,0] = %.1f (expect 1.0)\n", u32_to_f32(u0));
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Write remaining columns
|
||||
if (wid == 0) {
|
||||
for (int col = 1; col < 16; col++) {
|
||||
float v0 = 1.0f, v1 = 1.0f, v2 = 1.0f, v3 = 1.0f;
|
||||
tmem_store(tb_a + col, f32_to_u32(v0), f32_to_u32(v1), f32_to_u32(v2), f32_to_u32(v3));
|
||||
}
|
||||
if (lane == 0) printf("A sum (lane 0, col 0, pos 0..3): %.1f (expect 16.0)\n", check);
|
||||
tmem_fence_store();
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user