test: add TMEM lane mapping diagnostics

This commit is contained in:
2026-05-28 07:42:16 +00:00
parent 33cedbee0a
commit 593bc25afa

View File

@@ -0,0 +1,247 @@
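/**
* TMEM round-trip diagnostics — exercise the SMEM → TMEM → registers path.
* Known FP32 values are written to TMEM via a warp-collective store, read
* back via a warp-collective load, staged in shared memory and printed
* (test 2 additionally compares them with the expected values).
* No kernel in this file writes results to global memory.
*/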
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
#include <cstring>
// Minimal TMEM ops (same as fmha_common.cuh)
// 16-bit unsigned storage alias — presumably for raw bfloat16 bit patterns
// (inferred from the name only; TODO confirm). Nothing in the visible part of
// this file references it.
typedef unsigned short bf16_t;
/**
 * Reinterpret the bit pattern of an FP32 value as a 32-bit unsigned integer.
 * This is a pure bit copy — no numeric conversion takes place.
 */
__device__ __forceinline__ uint32_t f32_to_u32(float f) {
    uint32_t bits;
    memcpy(&bits, &f, sizeof(bits));  // sizeof(uint32_t) == 4, same as the float source
    return bits;
}
/**
 * Reinterpret a 32-bit unsigned integer as the FP32 value with the same bit
 * pattern. This is a pure bit copy — no numeric conversion takes place.
 */
__device__ __forceinline__ float u32_to_f32(uint32_t u) {
    float value;
    memcpy(&value, &u, sizeof(value));  // sizeof(float) == 4, same as the uint32_t source
    return value;
}
// Presumably the number of threads per warp (32 on NVIDIA GPUs).
// NOTE(review): the kernels visible in this file spell the warp size as a
// literal 32 and never reference this constant.
constexpr int WARP = 32;
/**
 * Allocate tensor-memory (TMEM) columns via `tcgen05.alloc.cta_group::1`.
 *
 * The instruction writes the address of the allocation into the 32-bit
 * shared-memory slot named by `smem_ptr`; the kernels in this file read it
 * back from that slot after a block barrier.
 *
 * @param smem_ptr  Shared-state-space address (e.g. from
 *                  __cvta_generic_to_shared) of the 32-bit result slot.
 * @param num_cols  Number of TMEM columns to allocate (callers here pass 32).
 *
 * The instruction is `.sync.aligned`: it must be executed by all threads of
 * the calling warp together.
 * NOTE(review): tcgen05 instructions require a target architecture that
 * supports them — confirm the build flags.
 */
__device__ void tmem_alloc(uint32_t smem_ptr, int num_cols) {
    // "memory" clobber: the instruction stores its result into shared memory,
    // so the compiler must not cache or reorder shared-memory accesses across it.
    asm volatile("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;"
                 :
                 : "r"(smem_ptr), "r"(num_cols)
                 : "memory");
}
/**
 * Release a tensor-memory (TMEM) allocation via `tcgen05.dealloc.cta_group::1`.
 *
 * @param tmem_ptr  TMEM address of the allocation to free; the kernels in this
 *                  file pass the value that tmem_alloc stored in shared memory.
 * @param num_cols  Column count to free; callers here pass the same 32 they
 *                  allocated.
 *
 * The instruction is `.sync.aligned`: it must be executed by all threads of
 * the calling warp together.
 */
__device__ void tmem_dealloc(uint32_t tmem_ptr, int num_cols) {
asm volatile("tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;"
:: "r"(tmem_ptr), "r"(num_cols));
}
/**
 * Issue `tcgen05.wait::st.sync.aligned` — per the instruction name, a wait on
 * previously issued tcgen05 stores (see the PTX ISA for the exact guarantee).
 * The kernels in this file call it right after tmem_store.
 *
 * The "memory" clobber also makes this a compiler-level barrier for ordinary
 * memory accesses. `.sync.aligned`: all threads of the warp must execute it.
 */
__device__ void tmem_fence_store() {
asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory");
}
/**
 * Warp-collective load from tensor memory (TMEM) into four 32-bit registers
 * per thread, using the `16x256b.x1` shape of `tcgen05.ld`.
 *
 * `tcgen05.ld` is asynchronous: the destination registers are only valid once
 * a `tcgen05.wait::ld` has completed. The wait is issued inside the same asm
 * statement so callers can use r0..r3 immediately after this function returns.
 *
 * @param col      TMEM address to load from; must be the same value in every
 *                 thread of the warp (collective operation).
 * @param r0..r3   Receive this thread's four 32-bit words of the loaded tile.
 *
 * `.sync.aligned`: must be executed by all threads of the warp together.
 */
__device__ void tmem_load(uint32_t col, uint32_t &r0, uint32_t &r1, uint32_t &r2, uint32_t &r3) {
    asm volatile(
        "tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4];\n\t"
        // Block until the asynchronous load has landed in the registers.
        "tcgen05.wait::ld.sync.aligned;"
        : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3)
        : "r"(col)
        : "memory");
}
/**
 * Warp-collective store of four 32-bit registers per thread into tensor
 * memory (TMEM), using the `16x256b.x1` shape of `tcgen05.st`.
 *
 * @param col      TMEM address to store to.
 *                 NOTE(review): as a collective operation this is expected to
 *                 be the same value in every thread of the warp — confirm
 *                 against the PTX ISA and the callers.
 * @param r0..r3   This thread's four 32-bit words of the tile to store.
 *
 * The function does not wait for the store to complete; the kernels in this
 * file call tmem_fence_store() afterwards.
 * `.sync.aligned`: must be executed by all threads of the warp together.
 */
__device__ void tmem_store(uint32_t col, uint32_t r0, uint32_t r1, uint32_t r2, uint32_t r3) {
asm volatile("tcgen05.st.sync.aligned.16x256b.x1.b32 [%0], {%1, %2, %3, %4};"
:: "r"(col), "r"(r0), "r"(r1), "r"(r2), "r"(r3));
}
/**
 * Test 1: lane-to-position mapping probe.
 *
 * Warp 0 tags every register it stores: lane i supplies
 * (4i, 4i+1, 4i+2, 4i+3) as FP32 to one warp-collective TMEM store at the
 * allocation base. It then issues the matching collective load, stages every
 * lane's four words in shared memory, and thread 0 prints what lanes 0..3
 * received next to what they wrote.
 *
 * Launch (as used by main): one block of 64 threads; only warp 0 touches
 * TMEM, the other warp just participates in the block barriers.
 * Dynamic shared memory: 4 bytes for the TMEM base slot, a 64-byte gap, then
 * 32 lanes * 4 words * 4 bytes of read-back data (580 bytes minimum).
 */
__global__ void test_lane_mapping() {
    extern __shared__ char sbuf[];
    // Byte layout: [0,4) TMEM base slot, [68, 580) per-lane read-back words.
    uint32_t* tmem_base_ptr = (uint32_t*)sbuf;
    uint32_t* sResults = (uint32_t*)(sbuf + 4 + 16*4);

    int lane = threadIdx.x % 32;
    int wid = threadIdx.x / 32;

    // Warp 0 allocates 32 TMEM columns; the address lands in shared memory.
    if (wid == 0) {
        uint32_t smem_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(tmem_base_ptr));
        tmem_alloc(smem_ptr, 32);
    }
    __syncthreads();
    uint32_t tmem_base = *tmem_base_ptr;

    // Collective store: all 32 lanes of warp 0 participate, each with its own tags.
    if (wid == 0) {
        float v0 = (float)(lane * 4 + 0);
        float v1 = (float)(lane * 4 + 1);
        float v2 = (float)(lane * 4 + 2);
        float v3 = (float)(lane * 4 + 3);
        tmem_store(tmem_base + 0, f32_to_u32(v0), f32_to_u32(v1), f32_to_u32(v2), f32_to_u32(v3));
        tmem_fence_store();
    }
    __syncthreads();

    // Collective load from the same address; every lane publishes its words.
    if (wid == 0) {
        uint32_t u0, u1, u2, u3;
        tmem_load(tmem_base + 0, u0, u1, u2, u3);
        sResults[lane * 4 + 0] = u0;
        sResults[lane * 4 + 1] = u1;
        sResults[lane * 4 + 2] = u2;
        sResults[lane * 4 + 3] = u3;
    }
    __syncthreads();

    // Thread 0 reports the first four lanes only.
    if (threadIdx.x == 0) {
        printf("Lane mapping test (column 0):\n");
        for (int i = 0; i < 4; i++) {
            float r0 = u32_to_f32(sResults[i*4+0]);
            float r1 = u32_to_f32(sResults[i*4+1]);
            float r2 = u32_to_f32(sResults[i*4+2]);
            float r3 = u32_to_f32(sResults[i*4+3]);
            printf(" Lane %2d wrote (%d,%d,%d,%d), read back (%.0f,%.0f,%.0f,%.0f)\n",
                   i, i*4+0, i*4+1, i*4+2, i*4+3, r0, r1, r2, r3);
        }
    }

    if (wid == 0) {
        // Lane 0 just ran a divergent printf branch; reconverge the warp
        // explicitly before the .sync.aligned dealloc instead of relying on
        // implicit warp synchrony.
        __syncwarp();
        tmem_dealloc(tmem_base, 32);
    }
}
/**
 * Test 2: four-value SMEM → TMEM → registers round trip.
 *
 * Thread 0 seeds shared memory with [1, 2, 3, 4]. Warp 0 then performs one
 * collective TMEM store in which lane 0 supplies those four values and every
 * other lane supplies zeros, reads the tile back with the matching collective
 * load, and lane 0 stages its four registers in shared memory. Thread 0
 * prints the values and a PASS/FAIL verdict (absolute tolerance 0.01).
 *
 * Launch (as used by main): one block of 64 threads; only warp 0 touches TMEM.
 * Dynamic shared memory: 4-byte TMEM base slot, 4 input floats, 4 result floats.
 */
__global__ void test_smem_to_tmem_roundtrip() {
    extern __shared__ char sbuf[];
    uint32_t* tmem_base_ptr = reinterpret_cast<uint32_t*>(sbuf);
    float* sData = reinterpret_cast<float*>(sbuf + 4);
    float* sResult = reinterpret_cast<float*>(sbuf + 4 + 4*4);

    const int lane = threadIdx.x % 32;
    const int wid = threadIdx.x / 32;
    const bool in_warp0 = (wid == 0);

    // Seed the input values 1.0 .. 4.0.
    if (threadIdx.x == 0) {
        for (int i = 0; i < 4; ++i) {
            sData[i] = (float)(i + 1);
        }
    }
    __syncthreads();

    // Allocate 32 TMEM columns; the base address is delivered through shared memory.
    if (in_warp0) {
        uint32_t slot_addr = __cvta_generic_to_shared(tmem_base_ptr);
        tmem_alloc(slot_addr, 32);
    }
    __syncthreads();
    const uint32_t tmem_base = *tmem_base_ptr;

    // Collective store: lane 0 carries the payload, all other lanes carry zeros.
    if (in_warp0) {
        uint32_t payload[4];
        #pragma unroll
        for (int i = 0; i < 4; ++i) {
            payload[i] = (lane == 0) ? f32_to_u32(sData[i]) : 0;
        }
        tmem_store(tmem_base + 0, payload[0], payload[1], payload[2], payload[3]);
        tmem_fence_store();
    }
    __syncthreads();

    // Collective load; only lane 0's registers are kept.
    if (in_warp0) {
        uint32_t w0, w1, w2, w3;
        tmem_load(tmem_base + 0, w0, w1, w2, w3);
        if (lane == 0) {
            sResult[0] = u32_to_f32(w0);
            sResult[1] = u32_to_f32(w1);
            sResult[2] = u32_to_f32(w2);
            sResult[3] = u32_to_f32(w3);
        }
    }
    __syncthreads();

    // Report and judge.
    if (threadIdx.x == 0) {
        printf("SMEM→TMEM round-trip: wrote [1,2,3,4], read [%.1f,%.1f,%.1f,%.1f]\n",
               sResult[0], sResult[1], sResult[2], sResult[3]);
        bool pass = true;
        for (int i = 0; i < 4; ++i) {
            pass = pass && (fabsf(sResult[i] - (float)(i + 1)) < 0.01f);
        }
        printf(" %s\n", pass ? "✅ PASS" : "❌ FAIL");
    }

    if (in_warp0) {
        tmem_dealloc(tmem_base, 32);
    }
}
/**
 * Test 3: multi-column write, then read back the first column group.
 *
 * tcgen05.st / tcgen05.ld are warp-collective: every thread of the warp must
 * pass the SAME TMEM address (per the PTX ISA, differing addresses are
 * undefined behaviour). The kernel therefore never uses a per-lane address.
 * Instead warp 0 issues one collective store per column group, each at a
 * warp-uniform address, and tags the data with the group index:
 *
 *   group g, lane l  ->  (100g + l, 100g + l + 0.1, 100g + l + 0.2, 100g + l + 0.3)
 *
 * A 16x256b.x1 access spans 256 bits, i.e. 8 32-bit TMEM columns, so the
 * 32-column allocation holds 4 non-overlapping groups (bases 0, 8, 16, 24)
 * and no store reaches past the allocation. Group 0 is then read back and
 * lane 0's registers are printed; they are expected to be [0.0, 0.1, 0.2, 0.3].
 *
 * Launch (as used by main): one block of 64 threads; only warp 0 touches TMEM.
 * Dynamic shared memory: 4-byte TMEM base slot followed by 4 result floats.
 */
__global__ void test_per_lane_columns() {
    constexpr int kAllocCols = 32;      // columns requested from tmem_alloc
    constexpr int kColsPerAccess = 8;   // 256 bits / 32-bit columns per 16x256b.x1 access

    extern __shared__ char sbuf[];
    uint32_t* tmem_base_ptr = (uint32_t*)sbuf;
    float* sResult = (float*)(sbuf + 4);

    int lane = threadIdx.x % 32;
    int wid = threadIdx.x / 32;

    if (wid == 0) {
        uint32_t smem_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(tmem_base_ptr));
        tmem_alloc(smem_ptr, kAllocCols);
    }
    __syncthreads();
    uint32_t tmem_base = *tmem_base_ptr;

    // One collective store per column group; the address depends only on the
    // loop counter, so it is identical across the warp.
    if (wid == 0) {
        for (int g = 0; g < kAllocCols / kColsPerAccess; ++g) {
            float tag = (float)(g * 100 + lane);
            tmem_store(tmem_base + g * kColsPerAccess,
                       f32_to_u32(tag),
                       f32_to_u32(tag + 0.1f),
                       f32_to_u32(tag + 0.2f),
                       f32_to_u32(tag + 0.3f));
        }
        tmem_fence_store();
    }
    __syncthreads();

    // Read back column group 0 (warp-uniform address); keep lane 0's registers.
    if (wid == 0) {
        uint32_t u0, u1, u2, u3;
        tmem_load(tmem_base + 0, u0, u1, u2, u3);
        if (lane == 0) {
            sResult[0] = u32_to_f32(u0);
            sResult[1] = u32_to_f32(u1);
            sResult[2] = u32_to_f32(u2);
            sResult[3] = u32_to_f32(u3);
        }
    }
    __syncthreads();

    if (threadIdx.x == 0) {
        printf("Per-lane columns: lane 0 writes col 0, reads [%.1f,%.1f,%.1f,%.1f]\n",
               sResult[0], sResult[1], sResult[2], sResult[3]);
        // A different result would mean the read-back did not return lane 0's
        // own group-0 registers.
        printf(" Expected for lane 0: [0.0, 0.1, 0.2, 0.3]\n");
    }

    if (wid == 0) {
        // Reconverge after the lane-0-only printf branch before the
        // .sync.aligned dealloc.
        __syncwarp();
        tmem_dealloc(tmem_base, kAllocCols);
    }
}
/**
 * Host driver: runs the three TMEM diagnostic kernels one after another.
 *
 * Each kernel is launched with one block of 64 threads and 4096 bytes of
 * dynamic shared memory. After every launch the launch status and the
 * synchronised execution status are checked, so a failing kernel (or a device
 * that cannot run it) is reported by name instead of being silently ignored.
 *
 * @return 0 when every launch and the final sync succeeded, 1 otherwise.
 */
int main() {
    // Checks the launch itself (cudaGetLastError) and then the asynchronous
    // execution (cudaDeviceSynchronize, which also flushes device printf).
    auto finish = [](const char* kernel_name) -> bool {
        cudaError_t status = cudaGetLastError();
        if (status == cudaSuccess) {
            status = cudaDeviceSynchronize();
        }
        if (status != cudaSuccess) {
            fprintf(stderr, "%s failed: %s\n", kernel_name, cudaGetErrorString(status));
            return false;
        }
        return true;
    };

    bool all_ok = true;

    printf("=== TMEM Lane Mapping Diagnostics ===\n\n");

    printf("Test 1: Lane-to-position mapping (1 column, all lanes)...\n");
    test_lane_mapping<<<1, 64, 4096>>>();
    all_ok = finish("test_lane_mapping") && all_ok;
    printf("\n");

    printf("Test 2: SMEM → TMEM → regs round-trip (4 values, lane 0 only)...\n");
    test_smem_to_tmem_roundtrip<<<1, 64, 4096>>>();
    all_ok = finish("test_smem_to_tmem_roundtrip") && all_ok;
    printf("\n");

    printf("Test 3: Per-lane column write + read back col 0...\n");
    test_per_lane_columns<<<1, 64, 4096>>>();
    all_ok = finish("test_per_lane_columns") && all_ok;
    printf("\n");

    cudaError_t err = cudaDeviceSynchronize();
    printf("Final sync: %s\n", cudaGetErrorString(err));

    // Propagate failure to the caller / CI instead of always returning 0.
    return (all_ok && err == cudaSuccess) ? 0 : 1;
}