// File: nvfp4-megamoe-kernel/tests/unit/test_tma_driver.cu
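/**
 * Minimal TMA load test, built manually and linked against the CUDA driver
 * library, e.g. `nvcc -arch=sm_90 test_tma_driver.cu -lcuda`.
 * Exercises: CUtensorMap creation (cuTensorMapEncodeTiled), the
 * cp.async.bulk.tensor.2d PTX instruction, and mbarrier-based completion sync.
 * Requires an sm_90 (Hopper) or newer GPU.
 */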
#include <cuda.h>
#include <cuda_runtime.h>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
using bf16_t = unsigned short;  // raw 16-bit BF16 bit pattern

// Converts a float to BF16 by keeping the upper 16 bits of its IEEE-754
// encoding (plain truncation: no rounding, no NaN special-casing).
static bf16_t f32_to_bf16_host(float f) {
  uint32_t bits;
  memcpy(&bits, &f, sizeof(bits));
  const uint32_t upper = bits >> 16;
  return static_cast<uint16_t>(upper);
}

// Widens a BF16 bit pattern back to float by placing it in the upper half of
// an FP32 encoding; the conversion is exact.
static float bf16_to_f32_host(bf16_t h) {
  const uint32_t bits = static_cast<uint32_t>(h) << 16;
  float result;
  memcpy(&result, &bits, sizeof(result));
  return result;
}
// Full tensor extent in elements: ROWS rows of COLS BF16 values (row-major).
constexpr int ROWS = 128;
constexpr int COLS = 16;
// TMA box (tile) extent in elements; with these values one box covers the
// whole tensor.
constexpr int TILE_ROWS = 128;
constexpr int TILE_COLS = 16;
/**
 * Single-box TMA load test kernel.
 *
 * Thread 0 sets up the load:
 *   - initialises one mbarrier;
 *   - arms it with the number of bytes the TMA unit will deliver;
 *   - issues one 2-D cp.async.bulk.tensor load of the TILE_COLS x TILE_ROWS box
 *     at coordinates (0, 0) into shared memory.
 * Every thread then waits on the mbarrier, and the block copies the tile to
 * gmem_dst.
 *
 * Launch: a single block; __launch_bounds__ caps it at 32 threads, and the
 * copy loop strides by blockDim.x.
 * Dynamic shared memory: at least 8 B (mbarrier) + up to 127 B (alignment
 * padding) + TILE_ROWS * TILE_COLS * sizeof(bf16_t) bytes.
 * Requires sm_90+ (cp.async.bulk.tensor, mbarrier.arrive.expect_tx).
 *
 * @param tma_desc  Pointer to a CUtensorMap in device-readable memory whose box
 *                  is TILE_COLS x TILE_ROWS BF16 elements.
 * @param gmem_dst  Destination with room for ROWS * COLS elements.
 */
__global__ void __launch_bounds__(32)
tma_test_kernel(const CUtensorMap* __restrict__ tma_desc, bf16_t* gmem_dst) {
  // The copy-back below reads ROWS * COLS elements out of a single box.
  static_assert(ROWS == TILE_ROWS && COLS == TILE_COLS,
                "tma_test_kernel loads one box, which must cover the whole tensor");
  constexpr uint32_t kTileBytes =
      static_cast<uint32_t>(TILE_ROWS * TILE_COLS * sizeof(bf16_t));

  extern __shared__ char sbuf[];
  uint64_t* sMbar = (uint64_t*)sbuf;
  // TMA destinations must be 128-byte aligned: skip past the mbarrier and round up.
  bf16_t* sData = (bf16_t*)(((uintptr_t)(sbuf + 8) + 127) & ~(uintptr_t)127);
  const uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar);

  if (threadIdx.x == 0) {
    // Expected arrival count 1: the single arrive.expect_tx issued below.
    asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;"
                 :: "r"(mbar_addr), "r"(1) : "memory");
    // Make the initialised barrier visible to the async proxy (TMA unit)
    // before any bulk copy signals it.
    asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
  }
  __syncthreads();

  if (threadIdx.x == 0) {
    const uint32_t smem_addr = (uint32_t)__cvta_generic_to_shared(sData);
    const uint64_t tma_addr = (uint64_t)tma_desc;
    // Arrive once and register the byte count the TMA will deliver. Without
    // this the pending arrival count never reaches zero, so the phase can
    // never complete and the wait below would spin forever.
    asm volatile("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;"
                 :: "r"(mbar_addr), "r"(kTileBytes) : "memory");
    // Box at coordinates (x = 0, y = 0); completion is reported to the
    // mbarrier as complete_tx of the copied bytes.
    asm volatile(
        "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes "
        "[%0], [%1, {%3, %4}], [%2];"
        :: "r"(smem_addr), "l"(tma_addr), "r"(mbar_addr), "r"(0), "r"(0)
        : "memory");
  }

  // Every thread waits for phase 0 to complete. A successful try_wait makes
  // the TMA-written shared memory visible to the waiting thread, so no extra
  // block barrier is needed before reading sData.
  asm volatile(
      "{\n\t"
      ".reg .pred p;\n\t"
      "LOOP:\n\t"
      "mbarrier.try_wait.parity.shared::cta.b64 p, [%0], %1;\n\t"
      "@p bra DONE;\n\t"
      "bra LOOP;\n\t"
      "DONE:\n\t"
      "}"
      :: "r"(mbar_addr), "r"(0)
      : "memory");

  // Copy SMEM -> GMEM, striding by the actual block size.
  for (int i = threadIdx.x; i < ROWS * COLS; i += blockDim.x) {
    gmem_dst[i] = sData[i];
  }
}
// Prints a diagnostic for a failed CUDA runtime call; returns true on success.
static bool tma_test_cuda_ok(cudaError_t err, const char* what) {
  if (err == cudaSuccess) return true;
  fprintf(stderr, "CUDA ERROR: %s failed: %s\n", what, cudaGetErrorString(err));
  return false;
}

/**
 * Host driver for tma_test_kernel.
 *
 * Steps:
 *   1. Fill a ROWS x COLS BF16 tensor with seeded pseudo-random values.
 *   2. Describe it with a 2-D CUtensorMap whose box is TILE_COLS x TILE_ROWS.
 *   3. Launch one 32-thread block that TMA-loads the box and copies it back.
 *   4. Compare the result bit-for-bit with the source.
 *
 * Returns 0 when every element matches and 1 on any API failure or mismatch.
 * Early error returns do not free resources, since the process exits anyway.
 */
int main() {
// Leave main with exit code 1 when a runtime-API call fails.
#define TMA_TEST_CHECK(call)                          \
  do {                                                \
    if (!tma_test_cuda_ok((call), #call)) return 1;   \
  } while (0)

  printf("=== TMA Load Test (driver API) ===\n");
  constexpr int TOTAL = ROWS * COLS;
  constexpr int DATA_BYTES = TOTAL * sizeof(bf16_t);

  bf16_t* h_src = (bf16_t*)malloc(DATA_BYTES);
  bf16_t* h_dst = (bf16_t*)calloc(TOTAL, sizeof(bf16_t));
  if (h_src == nullptr || h_dst == nullptr) {
    fprintf(stderr, "host allocation failed\n");
    free(h_src);
    free(h_dst);
    return 1;
  }
  srand(42);
  for (int i = 0; i < TOTAL; i++) {
    h_src[i] = f32_to_bf16_host((float)(rand() % 100) / 100.0f - 0.5f);
  }

  bf16_t *d_src = nullptr, *d_dst = nullptr;
  TMA_TEST_CHECK(cudaMalloc(&d_src, DATA_BYTES));
  TMA_TEST_CHECK(cudaMalloc(&d_dst, DATA_BYTES));
  TMA_TEST_CHECK(cudaMemcpy(d_src, h_src, DATA_BYTES, cudaMemcpyHostToDevice));
  // Poison the destination with 0xFFFF, a BF16 NaN pattern the generator above
  // never produces, so elements the kernel fails to write show up as mismatches.
  TMA_TEST_CHECK(cudaMemset(d_dst, 0xFF, DATA_BYTES));

  // Create CUtensorMap.
  // globalDim: sizes in ELEMENTS, innermost first (x = cols, y = rows).
  // globalStrides: byte strides of dimensions 1..rank-1 ONLY. Dimension 0 is
  // implicitly the element size, and every entry must be a multiple of 16.
  const cuuint64_t gdim[2] = {(cuuint64_t)COLS, (cuuint64_t)ROWS};
  const cuuint64_t gstr[1] = {(cuuint64_t)COLS * sizeof(bf16_t)};  // row stride
  // boxDim: tile extent in elements.
  // elementStrides: traversal step per dimension (1 = dense; the maximum is 8).
  const cuuint32_t tdim[2] = {TILE_COLS, TILE_ROWS};
  const cuuint32_t tstr[2] = {1, 1};
  CUtensorMap tma_desc_host;
  CUresult res = cuTensorMapEncodeTiled(
      &tma_desc_host, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 2,
      d_src, gdim, gstr, tdim, tstr,
      CU_TENSOR_MAP_INTERLEAVE_NONE,
      CU_TENSOR_MAP_SWIZZLE_NONE,
      CU_TENSOR_MAP_L2_PROMOTION_NONE,
      CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
  if (res != CUDA_SUCCESS) {
    const char* msg = nullptr;
    cuGetErrorString(res, &msg);
    printf("cuTensorMapEncodeTiled FAILED: %d (%s)\n", (int)res,
           msg != nullptr ? msg : "unknown error");
    return 1;
  }
  printf("CUtensorMap created OK\n");

  // The kernel takes the descriptor by pointer, so stage it in device memory.
  CUtensorMap* d_tma_desc = nullptr;
  TMA_TEST_CHECK(cudaMalloc(&d_tma_desc, sizeof(CUtensorMap)));
  TMA_TEST_CHECK(cudaMemcpy(d_tma_desc, &tma_desc_host, sizeof(CUtensorMap),
                            cudaMemcpyHostToDevice));

  // Shared memory layout:
  //   8 B mbarrier
  //   + up to 127 B padding to the 128-B-aligned tile
  //   + the tile itself
  //   + slack.
  int smem = 8 + 128 + DATA_BYTES + 256;
  tma_test_kernel<<<1, 32, smem>>>(d_tma_desc, d_dst);
  TMA_TEST_CHECK(cudaGetLastError());       // launch-configuration errors
  TMA_TEST_CHECK(cudaDeviceSynchronize());  // faults raised while the kernel ran
  TMA_TEST_CHECK(cudaMemcpy(h_dst, d_dst, DATA_BYTES, cudaMemcpyDeviceToHost));

  int mismatches = 0;
  for (int i = 0; i < TOTAL; i++) {
    if (h_src[i] != h_dst[i]) {
      if (mismatches == 0) {
        printf("First mismatch at %d: expected %f (0x%04x), got %f (0x%04x)\n", i,
               bf16_to_f32_host(h_src[i]), (unsigned)h_src[i],
               bf16_to_f32_host(h_dst[i]), (unsigned)h_dst[i]);
      }
      mismatches++;
    }
  }
  printf("Mismatches: %d / %d\n", mismatches, TOTAL);
  printf("Test %s\n", mismatches == 0 ? "PASSED" : "FAILED");

  TMA_TEST_CHECK(cudaFree(d_src));
  TMA_TEST_CHECK(cudaFree(d_dst));
  TMA_TEST_CHECK(cudaFree(d_tma_desc));
  free(h_src);
  free(h_dst);
#undef TMA_TEST_CHECK
  // A failed comparison must produce a non-zero exit status for test runners.
  return mismatches == 0 ? 0 : 1;
}