test: add TMEM lane mapping diagnostics
This commit is contained in:
247
tests/unit/test_tmem_lane_mapping.cu
Normal file
247
tests/unit/test_tmem_lane_mapping.cu
Normal file
@@ -0,0 +1,247 @@
|
||||
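/**
 * TMEM round-trip test — verify the SMEM → TMEM → regs pipeline.
 * Write known FP32 values to TMEM via a warp-collective store,
 * read them back via a warp-collective load, and compare with the expected
 * values. Results are staged in shared memory and printed from the device;
 * no kernel in this file writes results to global memory (GMEM).
 */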
|
||||
#include <cuda_runtime.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
|
||||
// Minimal TMEM ops (same as fmha_common.cuh)
|
||||
// Storage type for a raw bfloat16 bit pattern (16-bit unsigned container).
using bf16_t = unsigned short;
|
||||
|
||||
/**
 * Reinterpret the bits of a float as a 32-bit unsigned integer.
 * The bytes are copied unchanged (no numeric conversion takes place).
 */
__device__ __forceinline__ uint32_t f32_to_u32(float f) {
    uint32_t bits;
    memcpy(&bits, &f, sizeof(bits));
    return bits;
}
|
||||
/**
 * Reinterpret a 32-bit unsigned integer as a float.
 * The bytes are copied unchanged (no numeric conversion takes place).
 */
__device__ __forceinline__ float u32_to_f32(uint32_t u) {
    float value;
    memcpy(&value, &u, sizeof(value));
    return value;
}
|
||||
|
||||
// Presumably the number of threads per warp. NOTE(review): not referenced by
// the code visible in this file, which spells out the literal 32 instead.
constexpr int WARP = 32;
|
||||
|
||||
/**
 * Allocate `num_cols` columns of tensor memory (TMEM) for this CTA.
 *
 * @param smem_ptr  Shared-space address of a 32-bit word; the allocation's
 *                  TMEM base address is written there (the kernels in this
 *                  file read it back after a __syncthreads()).
 * @param num_cols  Number of TMEM columns to allocate (callers here pass 32).
 *
 * The instruction is .sync.aligned, so it must be executed by a whole warp
 * together; the kernels in this file call it from warp 0 only.
 */
__device__ void tmem_alloc(uint32_t smem_ptr, int num_cols) {
    asm volatile(
        "tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;"
        : /* no outputs */
        : "r"(smem_ptr), "r"(num_cols));
}
|
||||
/**
 * Release a TMEM allocation previously obtained through tmem_alloc().
 *
 * @param tmem_ptr  TMEM base address of the allocation.
 * @param num_cols  Number of columns to release (callers here pass the same
 *                  count they allocated, 32).
 *
 * The instruction is .sync.aligned, so it must be executed by a whole warp
 * together; the kernels in this file call it from warp 0 only.
 */
__device__ void tmem_dealloc(uint32_t tmem_ptr, int num_cols) {
    asm volatile(
        "tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;"
        : /* no outputs */
        : "r"(tmem_ptr), "r"(num_cols));
}
|
||||
/**
 * Wait for the calling thread's outstanding TMEM stores (tcgen05.wait::st).
 *
 * The "memory" clobber also stops the compiler from moving memory accesses
 * across this point. The instruction is .sync.aligned, so a whole warp must
 * execute it together.
 */
__device__ void tmem_fence_store() {
    asm volatile(
        "tcgen05.wait::st.sync.aligned;"
        : /* no outputs */
        : /* no inputs */
        : "memory");
}
|
||||
/**
 * Warp-collective TMEM load (tcgen05.ld, shape 16x256b, x1).
 *
 * @param col  TMEM address to load from. The instruction is warp-collective,
 *             so every lane is expected to pass the same address.
 * @param r0   Receives the first 32-bit word delivered to this thread.
 * @param r1   Receives the second word.
 * @param r2   Receives the third word.
 * @param r3   Receives the fourth word.
 *
 * tcgen05.ld completes asynchronously: the destination registers are not
 * valid until a tcgen05.wait::ld has been executed. The wait is issued inside
 * the same asm statement so r0..r3 are safe to use as soon as this function
 * returns. The "memory" clobber keeps the compiler from reordering the load
 * relative to surrounding memory operations (e.g. a preceding store fence).
 *
 * Both instructions are .sync.aligned, so a whole warp must call this
 * function together.
 */
__device__ void tmem_load(uint32_t col, uint32_t &r0, uint32_t &r1, uint32_t &r2, uint32_t &r3) {
    asm volatile(
        "tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4];\n\t"
        "tcgen05.wait::ld.sync.aligned;"
        : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3)
        : "r"(col)
        : "memory");
}
|
||||
/**
 * Warp-collective TMEM store (tcgen05.st, shape 16x256b, x1).
 *
 * @param col  TMEM address to store to.
 * @param r0   First 32-bit word contributed by this thread.
 * @param r1   Second word.
 * @param r2   Third word.
 * @param r3   Fourth word.
 *
 * This helper only issues the store; the kernels in this file follow it with
 * tmem_fence_store() before reading the data back. The instruction is
 * .sync.aligned, so a whole warp must call this function together.
 */
__device__ void tmem_store(uint32_t col, uint32_t r0, uint32_t r1, uint32_t r2, uint32_t r3) {
    asm volatile(
        "tcgen05.st.sync.aligned.16x256b.x1.b32 [%0], {%1, %2, %3, %4};"
        : /* no outputs */
        : "r"(col), "r"(r0), "r"(r1), "r"(r2), "r"(r3));
}
|
||||
|
||||
/**
 * Test 1: lane-to-position mapping probe.
 *
 * Warp 0 stores four FP32 tags per lane — lane i contributes
 * (4i, 4i+1, 4i+2, 4i+3) — at the base of a freshly allocated 32-column TMEM
 * region, loads from that same address, and stages each lane's four returned
 * words in shared memory. Thread 0 then prints, for lanes 0..3, what the lane
 * wrote next to what it read back, which shows where each lane's data ended
 * up.
 *
 * Launch: main() uses one block of 64 threads with 4096 bytes of dynamic
 * shared memory. Only warp 0 touches TMEM; the remaining threads only take
 * part in the block-wide barriers.
 */
__global__ void test_lane_mapping() {
    extern __shared__ char sbuf[];

    // Dynamic shared-memory layout (byte offsets):
    //   [0, 4)     TMEM base address, filled in by tmem_alloc
    //   [4, 68)    room for 16 floats (not used by this kernel)
    //   [68, 580)  32 lanes x 4 read-back words
    uint32_t* const base_slot = reinterpret_cast<uint32_t*>(sbuf);
    uint32_t* const readback  = reinterpret_cast<uint32_t*>(sbuf + 4 + 16 * 4);

    const int  lane_id  = threadIdx.x % 32;
    const bool in_warp0 = (threadIdx.x / 32) == 0;

    // Reserve 32 TMEM columns; the base address lands in base_slot.
    if (in_warp0) {
        tmem_alloc(__cvta_generic_to_shared(base_slot), 32);
    }
    __syncthreads();
    const uint32_t tmem_addr = *base_slot;

    // Every lane of warp 0 takes part in the collective store, each
    // contributing its own four consecutive tags.
    if (in_warp0) {
        const int first_tag = lane_id * 4;
        tmem_store(tmem_addr,
                   f32_to_u32(static_cast<float>(first_tag + 0)),
                   f32_to_u32(static_cast<float>(first_tag + 1)),
                   f32_to_u32(static_cast<float>(first_tag + 2)),
                   f32_to_u32(static_cast<float>(first_tag + 3)));
        tmem_fence_store();
    }
    __syncthreads();

    // Load from the same address and park each lane's words in shared memory.
    if (in_warp0) {
        uint32_t w0, w1, w2, w3;
        tmem_load(tmem_addr, w0, w1, w2, w3);

        uint32_t* const mine = readback + lane_id * 4;
        mine[0] = w0;
        mine[1] = w1;
        mine[2] = w2;
        mine[3] = w3;
    }
    __syncthreads();

    // Report from a single thread; only the first four lanes are shown.
    if (threadIdx.x == 0) {
        printf("Lane mapping test (column 0):\n");
        for (int src = 0; src < 4; ++src) {
            const uint32_t* const got = readback + src * 4;
            const int wrote = src * 4;
            printf(" Lane %2d wrote (%d,%d,%d,%d), read back (%.0f,%.0f,%.0f,%.0f)\n",
                   src, wrote + 0, wrote + 1, wrote + 2, wrote + 3,
                   u32_to_f32(got[0]), u32_to_f32(got[1]),
                   u32_to_f32(got[2]), u32_to_f32(got[3]));
        }
    }

    if (in_warp0) {
        tmem_dealloc(tmem_addr, 32);
    }
}
|
||||
|
||||
/**
 * Test 2: minimal SMEM -> TMEM -> registers round trip (4 values).
 *
 * Thread 0 seeds shared memory with [1, 2, 3, 4]. Warp 0 then issues one
 * collective TMEM store in which lane 0 contributes those four values and
 * every other lane contributes zeros, loads from the same address, and lane 0
 * writes the words it received back to shared memory. Thread 0 prints the
 * values and a PASS/FAIL verdict (each value must be within 0.01 of what was
 * written). This is the minimal FMHA epilogue test.
 *
 * Launch: main() uses one block of 64 threads with 4096 bytes of dynamic
 * shared memory. Only warp 0 touches TMEM.
 */
__global__ void test_smem_to_tmem_roundtrip() {
    extern __shared__ char sbuf[];

    // Dynamic shared-memory layout (byte offsets):
    //   [0, 4)    TMEM base address, filled in by tmem_alloc
    //   [4, 20)   four input floats
    //   [20, 36)  four read-back floats
    uint32_t* const base_slot = reinterpret_cast<uint32_t*>(sbuf);
    float* const    input     = reinterpret_cast<float*>(sbuf + 4);
    float* const    output    = reinterpret_cast<float*>(sbuf + 4 + 4 * 4);

    const int  lane_id  = threadIdx.x % 32;
    const bool in_warp0 = (threadIdx.x / 32) == 0;

    // Seed the input values from a single thread.
    if (threadIdx.x == 0) {
        input[0] = 1.0f;
        input[1] = 2.0f;
        input[2] = 3.0f;
        input[3] = 4.0f;
    }
    __syncthreads();

    // Reserve 32 TMEM columns (the minimum); the base address lands in base_slot.
    if (in_warp0) {
        tmem_alloc(__cvta_generic_to_shared(base_slot), 32);
    }
    __syncthreads();
    const uint32_t tmem_addr = *base_slot;

    // All lanes of warp 0 join the collective store; only lane 0 carries data.
    if (in_warp0) {
        const bool carries_data = (lane_id == 0);
        const uint32_t w0 = carries_data ? f32_to_u32(input[0]) : 0;
        const uint32_t w1 = carries_data ? f32_to_u32(input[1]) : 0;
        const uint32_t w2 = carries_data ? f32_to_u32(input[2]) : 0;
        const uint32_t w3 = carries_data ? f32_to_u32(input[3]) : 0;
        tmem_store(tmem_addr, w0, w1, w2, w3);
        tmem_fence_store();
    }
    __syncthreads();

    // Collective load; only lane 0 records what it received.
    if (in_warp0) {
        uint32_t g0, g1, g2, g3;
        tmem_load(tmem_addr, g0, g1, g2, g3);
        if (lane_id == 0) {
            output[0] = u32_to_f32(g0);
            output[1] = u32_to_f32(g1);
            output[2] = u32_to_f32(g2);
            output[3] = u32_to_f32(g3);
        }
    }
    __syncthreads();

    if (threadIdx.x == 0) {
        printf("SMEM→TMEM round-trip: wrote [1,2,3,4], read [%.1f,%.1f,%.1f,%.1f]\n",
               output[0], output[1], output[2], output[3]);

        // Element k must come back as k + 1, within a 0.01 tolerance.
        bool all_match = true;
        for (int k = 0; k < 4; ++k) {
            all_match = all_match && (fabsf(output[k] - static_cast<float>(k + 1)) < 0.01f);
        }
        printf(" %s\n", all_match ? "✅ PASS" : "❌ FAIL");
    }

    if (in_warp0) {
        tmem_dealloc(tmem_addr, 32);
    }
}
|
||||
|
||||
/**
 * Test 3: per-lane column write, then read back column 0.
 *
 * Each lane of warp 0 passes its own address (TMEM base + lane) to the TMEM
 * store, contributing (lane, lane+0.1, lane+0.2, lane+0.3). Warp 0 then loads
 * from the base address, lane 0 stages the four words it received in shared
 * memory, and thread 0 prints them next to the values lane 0 wrote.
 *
 * NOTE(review): the store is a warp-collective (.sync.aligned) instruction
 * but is given a different address by every lane. The PTX ISA appears to
 * require one common base address across the warp for tcgen05.st — confirm
 * that this per-lane form is defined before trusting the printed result.
 *
 * Launch: main() uses one block of 64 threads with 4096 bytes of dynamic
 * shared memory. Only warp 0 touches TMEM.
 */
__global__ void test_per_lane_columns() {
    extern __shared__ char sbuf[];

    // Dynamic shared-memory layout (byte offsets):
    //   [0, 4)   TMEM base address, filled in by tmem_alloc
    //   [4, 20)  four read-back floats
    uint32_t* const base_slot = reinterpret_cast<uint32_t*>(sbuf);
    float* const    output    = reinterpret_cast<float*>(sbuf + 4);

    const int  lane_id  = threadIdx.x % 32;
    const bool in_warp0 = (threadIdx.x / 32) == 0;

    // Reserve 32 TMEM columns; the base address lands in base_slot.
    if (in_warp0) {
        tmem_alloc(__cvta_generic_to_shared(base_slot), 32);
    }
    __syncthreads();
    const uint32_t tmem_addr = *base_slot;

    // Lane-tagged payload, stored at a lane-specific address.
    if (in_warp0) {
        const float whole = static_cast<float>(lane_id);
        tmem_store(tmem_addr + lane_id,
                   f32_to_u32(whole),
                   f32_to_u32(whole + 0.1f),
                   f32_to_u32(whole + 0.2f),
                   f32_to_u32(whole + 0.3f));
        tmem_fence_store();
    }
    __syncthreads();

    // Load from the base address; only lane 0 records what it received.
    if (in_warp0) {
        uint32_t g0, g1, g2, g3;
        tmem_load(tmem_addr, g0, g1, g2, g3);
        if (lane_id == 0) {
            output[0] = u32_to_f32(g0);
            output[1] = u32_to_f32(g1);
            output[2] = u32_to_f32(g2);
            output[3] = u32_to_f32(g3);
        }
    }
    __syncthreads();

    if (threadIdx.x == 0) {
        printf("Per-lane columns: lane 0 writes col 0, reads [%.1f,%.1f,%.1f,%.1f]\n",
               output[0], output[1], output[2], output[3]);
        // Lane 0 wrote 0.0, 0.1, 0.2, 0.3; a different read-back means the
        // lane-to-column mapping is not the one this test assumes.
        printf(" Expected for lane 0: [0.0, 0.1, 0.2, 0.3]\n");
    }

    if (in_warp0) {
        tmem_dealloc(tmem_addr, 32);
    }
}
|
||||
|
||||
/**
 * Host driver: runs the three TMEM diagnostic kernels one after another.
 *
 * Each kernel is launched with one block of 64 threads and 4096 bytes of
 * dynamic shared memory. After every launch the driver checks both the launch
 * status (cudaGetLastError) and the execution status (cudaDeviceSynchronize),
 * so an invalid launch configuration or an in-kernel fault is reported on
 * stderr instead of being silently dropped.
 *
 * @return 0 when every kernel launched and completed without a CUDA error,
 *         1 otherwise.
 */
int main() {
    constexpr int    kThreadsPerBlock = 64;
    constexpr size_t kDynamicSmem     = 4096;

    bool all_ok = true;

    // Collects launch + execution status for the kernel that was just
    // launched, reports any failure, and emits the blank separator line.
    auto finish_test = [&all_ok](const char* kernel_name) {
        cudaError_t status = cudaGetLastError();      // launch-time errors
        if (status == cudaSuccess) {
            status = cudaDeviceSynchronize();         // execution-time errors
        }
        if (status != cudaSuccess) {
            fprintf(stderr, "%s: CUDA error: %s\n", kernel_name, cudaGetErrorString(status));
            all_ok = false;
        }
        printf("\n");
    };

    printf("=== TMEM Lane Mapping Diagnostics ===\n\n");

    printf("Test 1: Lane-to-position mapping (1 column, all lanes)...\n");
    test_lane_mapping<<<1, kThreadsPerBlock, kDynamicSmem>>>();
    finish_test("test_lane_mapping");

    printf("Test 2: SMEM → TMEM → regs round-trip (4 values, lane 0 only)...\n");
    test_smem_to_tmem_roundtrip<<<1, kThreadsPerBlock, kDynamicSmem>>>();
    finish_test("test_smem_to_tmem_roundtrip");

    printf("Test 3: Per-lane column write + read back col 0...\n");
    test_per_lane_columns<<<1, kThreadsPerBlock, kDynamicSmem>>>();
    finish_test("test_per_lane_columns");

    // A final synchronize surfaces any sticky error left on the device.
    cudaError_t err = cudaDeviceSynchronize();
    printf("Final sync: %s\n", cudaGetErrorString(err));

    return (all_ok && err == cudaSuccess) ? 0 : 1;
}
|
||||
Reference in New Issue
Block a user