Files
nvfp4-megamoe-kernel/tests/unit/test_tmem_row_offset.cu

154 lines
5.3 KiB
Plaintext

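/**
* Minimal test: verify TMEM 32x32b.x8 read with a row_page offset.
*
* Write known data to TMEM, then read it back with different row offsets
* to verify that addr = tmem_base + ((row_page*32) << 16) + col_group*8
* correctly addresses different 32-row pages.
*
* Note: a tcgen05.ld/st issued by warp w may only access TMEM lanes
* [32*(w%4), 32*(w%4)+31], so each 32-row page has to be accessed by a
* different warp of the warpgroup.
*/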
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdint>
// Raw 16-bit container for a bfloat16 value (bit pattern only, no arithmetic).
typedef unsigned short bf16_t;

// Converts an fp32 value to bf16 using round-to-nearest-even (cvt.rn) and
// returns the resulting 16 raw bits.
__device__ __forceinline__ bf16_t f32_to_bf16(float f) {
    bf16_t bits;
    asm("cvt.rn.bf16.f32 %0, %1;"
        : "=h"(bits)
        : "f"(f));
    return bits;
}
// Widens raw bf16 bits (as produced by f32_to_bf16) back to an fp32 value
// using the PTX cvt.f32.bf16 conversion.
__device__ __forceinline__ float bf16_to_f32(bf16_t h) {
    float widened;
    asm("cvt.f32.bf16 %0, %1;"
        : "=f"(widened)
        : "h"(h));
    return widened;
}
// Allocates num_cols columns of Tensor Memory (TMEM) for this CTA; the
// hardware stores the resulting TMEM base address into shared memory at
// smem_ptr (a shared-window address, e.g. from __cvta_generic_to_shared).
//
// tcgen05.alloc is .sync.aligned: all 32 lanes of the issuing warp must
// execute this call together with identical operands. Per the PTX ISA,
// num_cols must be a power of two in [32, 512]. Requires sm_100a.
__device__ void tmem_alloc(uint32_t smem_ptr, int num_cols) {
    // "memory" clobber: the instruction writes shared memory, so the compiler
    // must not cache or hoist reads of that slot across this asm.
    asm volatile("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;"
                 :: "r"(smem_ptr), "r"(num_cols)
                 : "memory");
}
// Releases num_cols TMEM columns starting at tmem_ptr (expected to be the
// base address obtained from a prior tmem_alloc with the same column count).
// Like the allocation, tcgen05.dealloc is .sync.aligned: the whole warp must
// execute it together with identical operands.
__device__ void tmem_dealloc(uint32_t tmem_ptr, int num_cols) {
    const uint32_t taddr = tmem_ptr;
    const int ncols = num_cols;
    asm volatile(
        "tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;"
        : /* no outputs */
        : "r"(taddr), "r"(ncols));
}
// Test 1: each warp w writes a known pattern to its own 32-row TMEM page via
// 32x32b.x8 at addr = tmem_base + ((w*32) << 16), then reads it back through
// the same row-page address.
//
// Launch: one CTA of 32*k threads, 1 <= k <= 4 (one warp per row page).
// A tcgen05.ld/st issued by warp w may only touch TMEM lanes
// [32*(w%4), 32*(w%4)+31]; reading another warp's page is illegal, which is
// why every page is exercised by the warp that owns it.
//
// Output (needs >= 16*k floats): results[w*16 + l*8 + c] = column c of TMEM
// row w*32 + l, for l in {0, 1}. The value written to row r, column c is
// r*8 + c. Requires sm_100a (tcgen05).
__global__ void test_tmem_row_offset(float* results) {
    constexpr int TMEM_N = 128;        // TMEM columns to allocate
    constexpr int kRowsPerPage = 32;   // TMEM lanes reachable by one warp
    constexpr int kMaxWarps = 4;       // one warpgroup covers all 128 lanes

    // Condition is uniform across the CTA, so returning cannot strand a barrier.
    if (blockDim.x % 32 != 0 || blockDim.x > kMaxWarps * kRowsPerPage) return;

    const int lane = threadIdx.x % 32;
    const int warp = threadIdx.x / 32;

    __shared__ uint32_t s_tmem_base;

    // tcgen05.alloc is .sync.aligned: the whole warp must issue it, so gate
    // on the warp index rather than on a single lane.
    if (warp == 0) {
        tmem_alloc(static_cast<uint32_t>(__cvta_generic_to_shared(&s_tmem_base)), TMEM_N);
    }
    // Publish the allocated address (written to smem by the alloc) to all warps.
    asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
    __syncthreads();
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    const uint32_t tb = s_tmem_base;

    // Row-page address: TMEM lane index lives in bits [31:16], column in [15:0].
    const int row0 = warp * kRowsPerPage;
    const uint32_t page_addr = tb + (static_cast<uint32_t>(row0) << 16);
    const int row = row0 + lane;

    // Write: lane l of warp w stores row*8 + c into columns 0-7 of its row.
    {
        uint32_t v[8];
#pragma unroll
        for (int c = 0; c < 8; ++c) v[c] = __float_as_uint(static_cast<float>(row * 8 + c));
        asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};"
                     :: "r"(page_addr),
                        "r"(v[0]), "r"(v[1]), "r"(v[2]), "r"(v[3]),
                        "r"(v[4]), "r"(v[5]), "r"(v[6]), "r"(v[7]));
        asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory");
    }

    // Read back through the same row-page address.
    {
        float tmp[8];
        asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
                     : "=f"(tmp[0]), "=f"(tmp[1]), "=f"(tmp[2]), "=f"(tmp[3]),
                       "=f"(tmp[4]), "=f"(tmp[5]), "=f"(tmp[6]), "=f"(tmp[7])
                     : "r"(page_addr));
        // Registers from tcgen05.ld are only valid after the wait; the clobber
        // keeps the global stores below from being scheduled ahead of it.
        asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");
        if (lane < 2) {
#pragma unroll
            for (int c = 0; c < 8; ++c) results[warp * 16 + lane * 8 + c] = tmp[c];
        }
    }

    // Every warp must be done with TMEM before warp 0 frees it.
    asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
    __syncthreads();
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    if (warp == 0) tmem_dealloc(tb, TMEM_N);
}

// Host-side status check: on failure prints which call failed and returns false.
static bool cuda_ok(cudaError_t status, const char* what) {
    if (status == cudaSuccess) return true;
    printf("CUDA ERROR (%s): %s\n", what, cudaGetErrorString(status));
    return false;
}

int main() {
    constexpr int kNumWarps = 4;       // one warp per 32-row TMEM page (one warpgroup)
    constexpr int kLanesChecked = 2;   // lanes 0 and 1 of each warp report back
    constexpr int kCols = 8;           // columns per 32x32b.x8 access
    constexpr int kNumResults = 64;
    static_assert(kNumWarps * kLanesChecked * kCols <= kNumResults, "results buffer too small");

    printf("TMEM row offset addressing test\n");
    printf("================================\n");

    float* d_results = nullptr;
    if (!cuda_ok(cudaMalloc(&d_results, kNumResults * sizeof(float)), "cudaMalloc")) return 1;
    if (!cuda_ok(cudaMemset(d_results, 0, kNumResults * sizeof(float)), "cudaMemset")) {
        cudaFree(d_results);
        return 1;
    }

    test_tmem_row_offset<<<1, kNumWarps * 32>>>(d_results);
    // Launch-configuration errors surface here; execution faults at the sync.
    cudaError_t err = cudaGetLastError();
    if (err == cudaSuccess) err = cudaDeviceSynchronize();
    if (err != cudaSuccess) {
        printf("CUDA ERROR: %s\n", cudaGetErrorString(err));
        printf("Test FAILED (kernel crash/hang)\n");
        cudaFree(d_results);
        return 1;
    }

    float h_results[kNumResults];
    const bool copied = cuda_ok(cudaMemcpy(h_results, d_results, kNumResults * sizeof(float),
                                           cudaMemcpyDeviceToHost),
                                "cudaMemcpy");
    cudaFree(d_results);
    if (!copied) return 1;

    // Page p was written and read by warp p; row r column c must hold r*8 + c.
    int pass = 1;
    for (int page = 0; page < kNumWarps; ++page) {
        const int row0 = page * 32;
        printf("Read with row_offset=%d<<16 (rows %d-%d, warp %d):\n",
               row0, row0, row0 + 31, page);
        for (int l = 0; l < kLanesChecked; ++l) {
            const float* got = h_results + (page * kLanesChecked + l) * kCols;
            printf(" Row %d: ", row0 + l);
            for (int c = 0; c < kCols; ++c) {
                printf("%.1f ", got[c]);
                const float expected = (float)((row0 + l) * 8 + c);
                if (fabsf(got[c] - expected) > 0.01f) pass = 0;
            }
            printf("\n");
        }
    }

    printf("\nRow-page addressing verification: %s\n", pass ? "PASSED" : "FAILED");
    return pass ? 0 : 1;
}