diff --git a/tests/unit/test_tmem_lane_mapping.cu b/tests/unit/test_tmem_lane_mapping.cu new file mode 100644 index 00000000..3789df70 --- /dev/null +++ b/tests/unit/test_tmem_lane_mapping.cu @@ -0,0 +1,247 @@ +/** + * TMEM round-trip test — verify SMEM → TMEM → regs → GMEM pipeline. + * Write known FP32 values to TMEM via warp-collective store, + * read them back via warp-collective load, compare with expected. + */ +#include +#include +#include +#include + +// Minimal TMEM ops (same as fmha_common.cuh) +typedef unsigned short bf16_t; + +__device__ __forceinline__ uint32_t f32_to_u32(float f) { uint32_t u; memcpy(&u,&f,4); return u; } +__device__ __forceinline__ float u32_to_f32(uint32_t u) { float f; memcpy(&f,&u,4); return f; } + +constexpr int WARP = 32; + +__device__ void tmem_alloc(uint32_t smem_ptr, int num_cols) { + asm volatile("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;" + :: "r"(smem_ptr), "r"(num_cols)); +} +__device__ void tmem_dealloc(uint32_t tmem_ptr, int num_cols) { + asm volatile("tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;" + :: "r"(tmem_ptr), "r"(num_cols)); +} +__device__ void tmem_fence_store() { + asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory"); +} +__device__ void tmem_load(uint32_t col, uint32_t &r0, uint32_t &r1, uint32_t &r2, uint32_t &r3) { + asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4];" + : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3) : "r"(col)); +} +__device__ void tmem_store(uint32_t col, uint32_t r0, uint32_t r1, uint32_t r2, uint32_t r3) { + asm volatile("tcgen05.st.sync.aligned.16x256b.x1.b32 [%0], {%1, %2, %3, %4};" + :: "r"(col), "r"(r0), "r"(r1), "r"(r2), "r"(r3)); +} + +/** + * Test 1: Understand the lane-to-position mapping. + * + * Each lane writes a unique identifier to its 4 registers. + * Lane i writes: (i*4+0, i*4+1, i*4+2, i*4+3) as FP32. + * Then reads back and prints. + * + * This tells us which lane's data ends up where in the TMEM column. + */ +__global__ void test_lane_mapping() { + extern __shared__ char sbuf[]; + uint32_t* tmem_base_ptr = (uint32_t*)sbuf; + float* sData = (float*)(sbuf + 4); // 16 floats for our test + uint32_t* sResults = (uint32_t*)(sbuf + 4 + 16*4); // read-back results + + int lane = threadIdx.x % 32; + int wid = threadIdx.x / 32; + + // Alloc TMEM + if (wid == 0) { + uint32_t smem_ptr = __cvta_generic_to_shared(tmem_base_ptr); + tmem_alloc(smem_ptr, 32); + } + __syncthreads(); + uint32_t tmem_base = *tmem_base_ptr; + + // Write to TMEM: lane i writes (i*4, i*4+1, i*4+2, i*4+3) to column 0 + // ALL 32 lanes participate (warp-collective) + if (wid == 0) { + float v0 = (float)(lane * 4 + 0); + float v1 = (float)(lane * 4 + 1); + float v2 = (float)(lane * 4 + 2); + float v3 = (float)(lane * 4 + 3); + tmem_store(tmem_base + 0, f32_to_u32(v0), f32_to_u32(v1), f32_to_u32(v2), f32_to_u32(v3)); + tmem_fence_store(); + } + __syncthreads(); + + // Read back from TMEM column 0 + if (wid == 0) { + uint32_t u0, u1, u2, u3; + tmem_load(tmem_base + 0, u0, u1, u2, u3); + // Each lane stores its results to SMEM + sResults[lane * 4 + 0] = u0; + sResults[lane * 4 + 1] = u1; + sResults[lane * 4 + 2] = u2; + sResults[lane * 4 + 3] = u3; + } + __syncthreads(); + + // Print results from thread 0 (only print first few lanes) + if (threadIdx.x == 0) { + printf("Lane mapping test (column 0):\n"); + for (int i = 0; i < 4; i++) { + float r0 = u32_to_f32(sResults[i*4+0]); + float r1 = u32_to_f32(sResults[i*4+1]); + float r2 = u32_to_f32(sResults[i*4+2]); + float r3 = u32_to_f32(sResults[i*4+3]); + printf(" Lane %2d wrote (%d,%d,%d,%d), read back (%.0f,%.0f,%.0f,%.0f)\n", + i, i*4+0, i*4+1, i*4+2, i*4+3, r0, r1, r2, r3); + } + } + + if (wid == 0) tmem_dealloc(tmem_base, 32); +} + +/** + * Test 2: Write HD=4 values (1 column) to TMEM from SMEM, + * read back, compare. This is the minimal FMHA epilogue test. + */ +__global__ void test_smem_to_tmem_roundtrip() { + extern __shared__ char sbuf[]; + uint32_t* tmem_base_ptr = (uint32_t*)sbuf; + float* sData = (float*)(sbuf + 4); + float* sResult = (float*)(sbuf + 4 + 4*4); + + int lane = threadIdx.x % 32; + int wid = threadIdx.x / 32; + + // Init SMEM data: [1.0, 2.0, 3.0, 4.0] + [0, 0, 0, 0] (padding to 8) + if (threadIdx.x == 0) { + sData[0] = 1.0f; sData[1] = 2.0f; sData[2] = 3.0f; sData[3] = 4.0f; + } + __syncthreads(); + + // Alloc TMEM (32 cols minimum) + if (wid == 0) { + uint32_t smem_ptr = __cvta_generic_to_shared(tmem_base_ptr); + tmem_alloc(smem_ptr, 32); + } + __syncthreads(); + uint32_t tmem_base = *tmem_base_ptr; + + // Write sData[0..3] to TMEM column 0, rows 0-3 + // Lane 0 handles these 4 values, all lanes participate + if (wid == 0) { + // Lane 0: write sData[0..3]. Other lanes write 0. + uint32_t u0 = (lane == 0) ? f32_to_u32(sData[0]) : 0; + uint32_t u1 = (lane == 0) ? f32_to_u32(sData[1]) : 0; + uint32_t u2 = (lane == 0) ? f32_to_u32(sData[2]) : 0; + uint32_t u3 = (lane == 0) ? f32_to_u32(sData[3]) : 0; + tmem_store(tmem_base + 0, u0, u1, u2, u3); + tmem_fence_store(); + } + __syncthreads(); + + // Read back + if (wid == 0) { + uint32_t u0, u1, u2, u3; + tmem_load(tmem_base + 0, u0, u1, u2, u3); + // Lane 0 reads and stores to SMEM + if (lane == 0) { + sResult[0] = u32_to_f32(u0); + sResult[1] = u32_to_f32(u1); + sResult[2] = u32_to_f32(u2); + sResult[3] = u32_to_f32(u3); + } + } + __syncthreads(); + + if (threadIdx.x == 0) { + printf("SMEM→TMEM round-trip: wrote [1,2,3,4], read [%.1f,%.1f,%.1f,%.1f]\n", + sResult[0], sResult[1], sResult[2], sResult[3]); + int ok = (fabsf(sResult[0]-1.0f)<0.01f && fabsf(sResult[1]-2.0f)<0.01f && + fabsf(sResult[2]-3.0f)<0.01f && fabsf(sResult[3]-4.0f)<0.01f); + printf(" %s\n", ok ? "✅ PASS" : "❌ FAIL"); + } + + if (wid == 0) tmem_dealloc(tmem_base, 32); +} + +/** + * Test 3: Multi-lane write — understand how lanes map to data positions. + * We write 32 different columns (one per lane), each lane writes its lane_id + * to column = tmem_base + lane. Then we read back column 0 and see what + * lane 0 gets. + */ +__global__ void test_per_lane_columns() { + extern __shared__ char sbuf[]; + uint32_t* tmem_base_ptr = (uint32_t*)sbuf; + float* sResult = (float*)(sbuf + 4); + + int lane = threadIdx.x % 32; + int wid = threadIdx.x / 32; + + if (wid == 0) { + uint32_t smem_ptr = __cvta_generic_to_shared(tmem_base_ptr); + tmem_alloc(smem_ptr, 32); + } + __syncthreads(); + uint32_t tmem_base = *tmem_base_ptr; + + // Each lane writes (lane, lane+0.1, lane+0.2, lane+0.3) to its own column + if (wid == 0) { + float v0 = (float)lane; + float v1 = (float)lane + 0.1f; + float v2 = (float)lane + 0.2f; + float v3 = (float)lane + 0.3f; + tmem_store(tmem_base + lane, f32_to_u32(v0), f32_to_u32(v1), f32_to_u32(v2), f32_to_u32(v3)); + tmem_fence_store(); + } + __syncthreads(); + + // Read back column 0 + if (wid == 0) { + uint32_t u0, u1, u2, u3; + tmem_load(tmem_base + 0, u0, u1, u2, u3); + if (lane == 0) { + sResult[0] = u32_to_f32(u0); + sResult[1] = u32_to_f32(u1); + sResult[2] = u32_to_f32(u2); + sResult[3] = u32_to_f32(u3); + } + } + __syncthreads(); + + if (threadIdx.x == 0) { + printf("Per-lane columns: lane 0 writes col 0, reads [%.1f,%.1f,%.1f,%.1f]\n", + sResult[0], sResult[1], sResult[2], sResult[3]); + // Expected: 0.0, 0.1, 0.2, 0.3 (if lane 0's data is in col 0) + // Actual may differ if lane mapping is different + printf(" Expected for lane 0: [0.0, 0.1, 0.2, 0.3]\n"); + } + + if (wid == 0) tmem_dealloc(tmem_base, 32); +} + +int main() { + printf("=== TMEM Lane Mapping Diagnostics ===\n\n"); + + printf("Test 1: Lane-to-position mapping (1 column, all lanes)...\n"); + test_lane_mapping<<<1, 64, 4096>>>(); + cudaDeviceSynchronize(); + printf("\n"); + + printf("Test 2: SMEM → TMEM → regs round-trip (4 values, lane 0 only)...\n"); + test_smem_to_tmem_roundtrip<<<1, 64, 4096>>>(); + cudaDeviceSynchronize(); + printf("\n"); + + printf("Test 3: Per-lane column write + read back col 0...\n"); + test_per_lane_columns<<<1, 64, 4096>>>(); + cudaDeviceSynchronize(); + printf("\n"); + + cudaError_t err = cudaDeviceSynchronize(); + printf("Final sync: %s\n", cudaGetErrorString(err)); + return 0; +}