Files
nvfp4-megamoe-kernel/tests/unit/test_tma_verify.cu

151 lines
5.5 KiB
Plaintext

/**
* Verify TMA load correctness by comparing against direct GMEM read.
* Isolates TMA + mbarrier + canonical write from the rest of the FMHA pipeline.
*/
#include <cuda_runtime.h>
#include <cuda.h>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
#include "dsv4/kernels/attention/fmha_tma.cuh"
using namespace dsv4::kernels::attention;
// Host-side float -> bf16: keeps the upper 16 bits of the float's bit pattern
// (plain truncation of the low 16 bits, no rounding step).
static bf16_t f32_to_bf16_host(float f) {
    uint32_t bits;
    memcpy(&bits, &f, sizeof(bits));  // bit-copy the float instead of pointer-casting
    const uint32_t upper_half = bits >> 16;
    return (uint16_t)upper_half;
}
// Host-side bf16 -> float: places the 16 stored bits in the upper half of a
// 32-bit word (low half zero) and reinterprets that word as a float.
static float bf16_to_f32_host(bf16_t h) {
    const uint32_t bits = (uint32_t)h << 16;
    float value;
    memcpy(&value, &bits, sizeof(value));  // bit-copy instead of pointer-casting
    return value;
}
// Shape of the single tile exercised by this test: ROWS x COLS elements.
// Both the kernel (buffer sizes, loop bounds) and main() (allocation sizes,
// descriptor creation, host-side index math) are written against these values.
constexpr int ROWS = 128;
constexpr int COLS = 16; // one TMA sub-tile
/**
 * Test 1: load one (ROWS, COLS) = (128, 16) tile two ways and emit both
 * results after the same canonical conversion so the host can diff them.
 *
 *   TMA path    : tma_load_2d -> mbarrier wait -> write_smem_canonical
 *   Direct path : linear GMEM read               -> write_smem_canonical
 *
 * Launch requirements:
 *   - a single block whose warp 0 is a full warp (warp 0 stages the data with
 *     a lane stride of 32 and runs the canonical conversion); the copy-out
 *     loops stride by blockDim.x, so any such block size works
 *   - dynamic shared memory for two ROWS*COLS bf16_t buffers, the second one
 *     rounded up to a 128-byte boundary
 *
 * @param gmem_data         source tile, read linearly by the direct path
 * @param result_canonical  output, ROWS*COLS elements produced by the TMA path
 * @param result_direct     output, ROWS*COLS elements produced by the direct path
 * @param tma_desc          pointer to the tensor map, forwarded to tma_load_2d
 */
__global__ void test_tma_load_kernel(
    const bf16_t* __restrict__ gmem_data,
    bf16_t* __restrict__ result_canonical,  // output: canonical layout from TMA path
    bf16_t* __restrict__ result_direct,     // output: canonical layout from direct path
    CUtensorMap* __restrict__ tma_desc
) {
    // Carve the dynamic shared memory into a staging buffer (TMA destination,
    // later reused by the direct path) and a buffer for the canonical result.
    extern __shared__ __align__(128) char sbuf[];
    size_t off = 0;
    bf16_t* sData_tma = (bf16_t*)(sbuf + off); off += ROWS * COLS * sizeof(bf16_t);
    off = (off + 127) & ~(size_t)127;  // keep the second buffer 128-byte aligned
    bf16_t* sData_canonical = (bf16_t*)(sbuf + off);

    const int tid = threadIdx.x;
    const int lane = tid % 32;
    const int wid = tid / 32;
    // Stride for the block-wide copy-out loops; taken from the launch so the
    // kernel does not silently depend on being launched with exactly 128 threads.
    const int nthreads = blockDim.x;

    // Set up mbarrier
    __shared__ uint64_t sMbar[1];
    const uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar);
    if (tid == 0) {
        tma_mbarrier_init(mbar_addr, 1);
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
    }
    __syncthreads();  // every thread must see the initialised barrier before waiting on it

    // --- TMA path ---
    // One thread issues the load at tile coordinate (0, 0) and registers the
    // number of bytes the barrier should expect (ROWS * COLS two-byte elements).
    if (wid == 0 && lane == 0) {
        uint32_t smem_dst = (uint32_t)__cvta_generic_to_shared(sData_tma);
        tma_load_2d(smem_dst, (uint64_t)tma_desc, mbar_addr, 0, 0);
        tma_mbarrier_arrive_expect_tx(mbar_addr, ROWS * COLS * 2);
    }
    // All threads wait on the barrier (second argument presumably the phase
    // parity of this first use -- confirm against fmha_tma.cuh).
    tma_mbarrier_wait(mbar_addr, 0);
    __syncthreads();

    // Convert to canonical
    if (wid == 0) write_smem_canonical<ROWS, COLS, 32>(sData_canonical, sData_tma);
    __syncthreads();

    // Write canonical result to GMEM
    for (int i = tid; i < ROWS * COLS; i += nthreads) {
        result_canonical[i] = sData_canonical[i];
    }

    // --- Direct path ---
    // Overwrite the staging buffer straight from GMEM, then run the same
    // conversion. Only sData_tma is touched here, so this cannot race with the
    // copy-out above, which reads sData_canonical.
    if (wid == 0) {
        for (int i = lane; i < ROWS * COLS; i += 32) {
            sData_tma[i] = gmem_data[i];  // direct GMEM read (row-major)
        }
    }
    // This barrier also guarantees every thread has finished reading the TMA
    // result out of sData_canonical before warp 0 overwrites it below.
    __syncthreads();
    if (wid == 0) write_smem_canonical<ROWS, COLS, 32>(sData_canonical, sData_tma);
    __syncthreads();

    for (int i = tid; i < ROWS * COLS; i += nthreads) {
        result_direct[i] = sData_canonical[i];
    }
}
/**
 * Host driver for the TMA load check.
 *
 * Fills a ROWS x COLS tile with deterministic pseudo-random values, runs
 * test_tma_load_kernel on it, and compares the kernel's TMA-path output with
 * its direct-path output element by element. A second, informational
 * comparison maps the host data through the canonical index formula and
 * checks it against the TMA output; it is printed but does not affect the
 * verdict.
 *
 * Every CUDA call and the kernel launch itself are checked, and the two result
 * buffers are pre-filled with different byte patterns, so a kernel that never
 * ran (or never wrote) cannot produce a spurious PASSED.
 *
 * @return 0 if the TMA and direct outputs match exactly; 1 on any mismatch,
 *         allocation failure, descriptor failure, or CUDA error.
 */
int main() {
    printf("TMA Load Verification Test\n");

    const size_t tile_bytes = ROWS * COLS * sizeof(bf16_t);

    // All owned resources start out null so release_all() is safe on every exit path.
    bf16_t* h_data = nullptr;
    bf16_t* h_tma = nullptr;
    bf16_t* h_direct = nullptr;
    bf16_t* d_data = nullptr;
    bf16_t* d_result_tma = nullptr;
    bf16_t* d_result_direct = nullptr;
    CUtensorMap* d_tma_desc = nullptr;

    auto release_all = [&]() {
        cudaFree(d_data); cudaFree(d_result_tma); cudaFree(d_result_direct); cudaFree(d_tma_desc);
        free(h_data); free(h_tma); free(h_direct);
    };

    // Reports a failed CUDA call (with the step that failed) and returns false.
    auto cuda_ok = [](cudaError_t err, const char* what) -> bool {
        if (err == cudaSuccess) return true;
        printf("%s failed\n", what);
        printf("CUDA ERROR: %s\n", cudaGetErrorString(err));
        return false;
    };

    // Allocate host data
    h_data = (bf16_t*)malloc(tile_bytes);
    h_tma = (bf16_t*)malloc(tile_bytes);
    h_direct = (bf16_t*)malloc(tile_bytes);
    if (!h_data || !h_tma || !h_direct) {
        printf("Host allocation FAILED\n");
        release_all();
        return 1;
    }
    srand(42);  // fixed seed: identical input on every run
    for (int i = 0; i < ROWS * COLS; i++) h_data[i] = f32_to_bf16_host((float)(rand() % 100) / 100.0f - 0.5f);

    // Allocate device buffers and upload the input. The result buffers get
    // different fill bytes (0x00 vs 0xFF) so that outputs the kernel failed to
    // write show up as mismatches instead of comparing equal by accident.
    if (!cuda_ok(cudaMalloc(&d_data, tile_bytes), "cudaMalloc(d_data)") ||
        !cuda_ok(cudaMalloc(&d_result_tma, tile_bytes), "cudaMalloc(d_result_tma)") ||
        !cuda_ok(cudaMalloc(&d_result_direct, tile_bytes), "cudaMalloc(d_result_direct)") ||
        !cuda_ok(cudaMemset(d_result_tma, 0x00, tile_bytes), "cudaMemset(d_result_tma)") ||
        !cuda_ok(cudaMemset(d_result_direct, 0xFF, tile_bytes), "cudaMemset(d_result_direct)") ||
        !cuda_ok(cudaMemcpy(d_data, h_data, tile_bytes, cudaMemcpyHostToDevice), "cudaMemcpy(d_data)")) {
        release_all();
        return 1;
    }

    // Create TMA descriptor for (128, 16) BF16
    CUtensorMap tma_desc;
    if (!create_tma_desc_2d_bf16(&tma_desc, d_data, ROWS, COLS, ROWS, COLS)) {
        printf("TMA descriptor creation FAILED\n");
        release_all();
        return 1;
    }
    // The kernel takes the descriptor by device pointer, so stage a copy in GMEM.
    if (!cuda_ok(cudaMalloc(&d_tma_desc, sizeof(CUtensorMap)), "cudaMalloc(d_tma_desc)") ||
        !cuda_ok(cudaMemcpy(d_tma_desc, &tma_desc, sizeof(CUtensorMap), cudaMemcpyHostToDevice),
                 "cudaMemcpy(d_tma_desc)")) {
        release_all();
        return 1;
    }

    int smem = ROWS * COLS * 2 * 2 + 128;  // two SMEM buffers + slack
    test_tma_load_kernel<<<1, 128, smem>>>(d_data, d_result_tma, d_result_direct, d_tma_desc);
    // Launch-time failures (bad configuration, no kernel image for this GPU)
    // are reported through cudaGetLastError(); faults during execution surface
    // at the synchronize.
    if (!cuda_ok(cudaGetLastError(), "kernel launch") ||
        !cuda_ok(cudaDeviceSynchronize(), "kernel execution")) {
        release_all();
        return 1;
    }

    // Compare
    if (!cuda_ok(cudaMemcpy(h_tma, d_result_tma, tile_bytes, cudaMemcpyDeviceToHost), "cudaMemcpy(h_tma)") ||
        !cuda_ok(cudaMemcpy(h_direct, d_result_direct, tile_bytes, cudaMemcpyDeviceToHost),
                 "cudaMemcpy(h_direct)")) {
        release_all();
        return 1;
    }
    int mismatches = 0;
    for (int i = 0; i < ROWS * COLS; i++) {
        if (h_tma[i] != h_direct[i]) mismatches++;
    }
    printf("TMA vs Direct canonical: %d mismatches out of %d\n", mismatches, ROWS * COLS);

    // Also compare TMA canonical against host data directly (row-major → canonical)
    // For (128, 16): CORES_MN=16, CORES_K=2
    // canonical[core_k * 16 * 64 + core_mn * 64 + local_r * 8 + local_c]
    // Informational only: this count is printed but is not part of the verdict.
    int canonical_mismatches = 0;
    for (int i = 0; i < ROWS * COLS; i++) {
        int r = i / COLS, c = i % COLS;
        int core_mn = r / 8, core_k = c / 8, local_r = r % 8, local_c = c % 8;
        int canon_idx = core_k * 16 * 64 + core_mn * 64 + local_r * 8 + local_c;
        if (h_data[i] != h_tma[canon_idx]) canonical_mismatches++;
    }
    printf("Host data → canonical: %d mismatches (checks mapping correctness)\n", canonical_mismatches);

    printf("%s\n", mismatches == 0 ? "PASSED" : "FAILED");
    release_all();
    return mismatches == 0 ? 0 : 1;
}