Files
nvfp4-megamoe-kernel/tests/unit/test_tma_kload.cu

134 lines
4.7 KiB
Plaintext

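/**
* Absolute minimum TMA K-load test:
* 1. Load a (128, 16) bf16 K tile via TMA into sTmaBuf
* 2. Re-lay it out into the canonical 8x8 core-matrix order in sCanonical
* 3. Write both sCanonical and the raw row-major sTmaBuf to GMEM for verification
* No TMEM, no MMA.
*/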
#include <cuda_runtime.h>
#include <cuda.h>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#ifndef HD_VAL
#define HD_VAL 64
#endif
#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
#include "dsv4/kernels/attention/fmha_tma.cuh"
using namespace dsv4::kernels::attention;
// Host-side float -> bf16: keeps the high 16 bits of the IEEE-754 float bit pattern
// (sign, exponent, top 7 mantissa bits). The low 16 bits are dropped (truncation, no rounding).
static bf16_t f32_to_bf16_host(float f) {
    uint32_t bits = 0;
    std::memcpy(&bits, &f, sizeof(bits));
    const uint16_t high_half = static_cast<uint16_t>(bits >> 16);
    return high_half;
}
// Host-side bf16 -> float: places the 16 bf16 bits in the high half of a float bit pattern
// with a zero low half, which widens the value exactly.
static float bf16_to_f32_host(bf16_t h) {
    const uint32_t bits = static_cast<uint32_t>(h) << 16;
    float result;
    std::memcpy(&result, &bits, sizeof(result));
    return result;
}
// Head dim selected at compile time via -DHD_VAL; in this test it is only printed in the banner.
constexpr int HD = HD_VAL;
// Rows of K covered by the single TMA box / tile.
constexpr int SK = 128;
// Columns of the tile; taken from MMA_K_BF16 in the included attention headers.
constexpr int MMA_K = MMA_K_BF16;
// Number of 8-row core matrices along the row dimension of the tile (derived, not hard-coded).
constexpr int CORES_MN = SK / 8;
// Elements in one (SK, MMA_K) tile; used for SMEM buffer sizes and the TMA expect_tx byte count.
constexpr int TILE_SZ = SK * MMA_K;
// The canonical re-layout splits the tile into 8x8 core matrices in both dimensions.
static_assert(SK % 8 == 0, "SK must be a multiple of the 8-row core-matrix height");
static_assert(MMA_K % 8 == 0, "MMA_K must be a multiple of the 8-column core-matrix width");
/**
 * Single-CTA kernel: TMA-loads one (SK, MMA_K) bf16 tile of K into SMEM, re-lays it out
 * into 8x8 core-matrix ("canonical") order, and copies both SMEM buffers back to GMEM.
 *
 * Launch: 1 block, at most 128 threads (written for 128; all loops stride by blockDim.x).
 * Dynamic SMEM: at least 2 * TILE_SZ * sizeof(bf16_t) + sizeof(uint64_t) bytes, laid out as
 *   [sTmaBuf: TILE_SZ bf16][sCanonical: TILE_SZ bf16][sMbar: one uint64_t mbarrier].
 *
 * out_canonical: TILE_SZ elements; receives sCanonical.
 * out_rowmajor:  TILE_SZ elements; receives the tile exactly as TMA wrote it to SMEM.
 * tma_k:         device pointer to the K tensor map; its address is handed to tma_load_2d.
 *                NOTE(review): CUDA recommends passing a CUtensorMap as a `const __grid_constant__`
 *                kernel parameter; here it lives in global memory — confirm no proxy fence is needed.
 */
__global__ void __launch_bounds__(128)
test_tma_kload_kernel(
    bf16_t* __restrict__ out_canonical,  // (128, 16) canonical
    bf16_t* __restrict__ out_rowmajor,   // (128, 16) row-major
    CUtensorMap* __restrict__ tma_k
) {
    const int tid = threadIdx.x;
    const int nthreads = blockDim.x;

    // Simple SMEM: sTmaBuf (128,16) + sCanonical (128,16) + sMbar
    extern __shared__ __align__(128) char sbuf[];
    bf16_t* sTmaBuf = (bf16_t*)(sbuf);
    bf16_t* sCanonical = (bf16_t*)(sbuf + TILE_SZ * sizeof(bf16_t));
    uint64_t* sMbar = (uint64_t*)(sbuf + 2 * TILE_SZ * sizeof(bf16_t));
    const uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar);

    // One thread initializes the mbarrier with arrival count 1 (the single arrive_expect_tx
    // below) and fences the init so it is visible before the barrier is used.
    if (tid == 0) {
        tma_mbarrier_init(mbar_addr, 1);
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
    }
    __syncthreads();

    // TMA load: one tile at coord {0, 0}. The same thread then arrives on the barrier and
    // registers the expected byte count; phase 0 cannot complete before that single arrival.
    if (tid == 0) {
        uint32_t smem_dst = (uint32_t)__cvta_generic_to_shared(sTmaBuf);
        tma_load_2d(smem_dst, (uint64_t)tma_k, mbar_addr, 0, 0);
        tma_mbarrier_arrive_expect_tx(mbar_addr, TILE_SZ * sizeof(bf16_t));
    }
    // Every thread waits on phase 0 itself; the block barrier is kept as a conservative sync point.
    tma_mbarrier_wait(mbar_addr, 0);
    __syncthreads();

    // Convert to canonical: zero-fill, then scatter the row-major tile.
    // The barrier between the two loops is required: thread t zeroes index i but scatters to
    // a different canonical index, so without it a late zero-fill from another warp could
    // overwrite an element that has already been scattered.
    for (int i = tid; i < TILE_SZ; i += nthreads) sCanonical[i] = 0;
    __syncthreads();
    for (int i = tid; i < SK * MMA_K; i += nthreads) {
        const int r = i / MMA_K, c = i % MMA_K;
        // (r, c) lands in core matrix (tmn = r/8, ck = c/8) at position (lr, lc). Each 8x8 core
        // matrix is 64 contiguous elements, core matrices along rows are 64 apart, and each
        // 8-column slab (ck) spans CORES_MN core matrices.
        const int ck = c / 8, lc = c % 8, tmn = r / 8, lr = r % 8;
        sCanonical[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = sTmaBuf[i];
    }
    __syncthreads();

    // Write canonical to GMEM
    for (int i = tid; i < TILE_SZ; i += nthreads) out_canonical[i] = sCanonical[i];
    // Write row-major to GMEM
    for (int i = tid; i < TILE_SZ; i += nthreads) out_rowmajor[i] = sTmaBuf[i];
}
// Prints a diagnostic naming the failed call and returns false when `err` is not cudaSuccess.
static bool cuda_call_ok(cudaError_t err, const char* what) {
    if (err == cudaSuccess) return true;
    printf("CUDA ERROR (%s): %s\n", what, cudaGetErrorString(err));
    return false;
}

/**
 * Host driver for the test. Steps:
 *   - Fill a (SK, MMA_K) bf16 K matrix with seeded pseudo-random values.
 *   - Load it with one TMA box via test_tma_kload_kernel.
 *   - Check the raw row-major SMEM copy against the host data.
 *   - Check the canonical re-layout against the host data.
 * Returns 0 when both outputs match bit-exactly (PASSED), 1 on any CUDA error or mismatch.
 */
int main() {
    printf("TMA K-load only (HD=%d)\n", HD);
    const size_t k_bytes = (size_t)SK * MMA_K * sizeof(bf16_t);
    const size_t tile_bytes = (size_t)TILE_SZ * sizeof(bf16_t);

    bf16_t *h_k = nullptr, *h_c = nullptr, *h_r = nullptr;
    bf16_t *d_k = nullptr, *d_out_c = nullptr, *d_out_r = nullptr;
    CUtensorMap* d_tma_k = nullptr;
    // Releases every host/device buffer; free/cudaFree accept null, so this is safe on any exit path.
    auto cleanup = [&]() {
        cudaFree(d_tma_k);
        cudaFree(d_out_r);
        cudaFree(d_out_c);
        cudaFree(d_k);
        free(h_r);
        free(h_c);
        free(h_k);
    };

    h_k = (bf16_t*)calloc(SK * MMA_K, sizeof(bf16_t));
    h_c = (bf16_t*)malloc(tile_bytes);
    h_r = (bf16_t*)malloc(tile_bytes);
    if (!h_k || !h_c || !h_r) {
        printf("Host allocation FAILED\n");
        cleanup();
        return 1;
    }
    srand(42);
    for (int i = 0; i < SK * MMA_K; i++) h_k[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f);

    if (!cuda_call_ok(cudaMalloc(&d_k, k_bytes), "cudaMalloc d_k") ||
        !cuda_call_ok(cudaMalloc(&d_out_c, tile_bytes), "cudaMalloc d_out_c") ||
        !cuda_call_ok(cudaMalloc(&d_out_r, tile_bytes), "cudaMalloc d_out_r") ||
        !cuda_call_ok(cudaMemcpy(d_k, h_k, k_bytes, cudaMemcpyHostToDevice), "cudaMemcpy K H2D")) {
        cleanup();
        return 1;
    }

    // The box dims are tied to SK/MMA_K, so each TMA delivers SK * MMA_K * sizeof(bf16_t) bytes.
    // That is the byte count the kernel passes to expect_tx.
    // Assumes the last two arguments are (box_rows, box_cols) — TODO confirm against fmha_tma.cuh.
    CUtensorMap tma_k;
    if (!create_tma_desc_2d_bf16(&tma_k, d_k, SK, (uint64_t)MMA_K, SK, MMA_K)) {
        printf("TMA desc FAILED\n");
        cleanup();
        return 1;
    }
    if (!cuda_call_ok(cudaMalloc(&d_tma_k, sizeof(CUtensorMap)), "cudaMalloc d_tma_k") ||
        !cuda_call_ok(cudaMemcpy(d_tma_k, &tma_k, sizeof(CUtensorMap), cudaMemcpyHostToDevice),
                      "cudaMemcpy tensor map")) {
        cleanup();
        return 1;
    }

    // Two bf16 tiles plus the 8-byte mbarrier, padded to 16 bytes.
    const size_t smem = 2 * tile_bytes + 16;
    test_tma_kload_kernel<<<1, 128, smem>>>(d_out_c, d_out_r, d_tma_k);
    // Launch-configuration errors are only reported through cudaGetLastError();
    // faults during execution surface at the synchronize.
    cudaError_t err = cudaGetLastError();
    if (err == cudaSuccess) err = cudaDeviceSynchronize();
    if (err != cudaSuccess) {
        printf("CUDA ERROR: %s\n", cudaGetErrorString(err));
        cleanup();
        return 1;
    }

    if (!cuda_call_ok(cudaMemcpy(h_c, d_out_c, tile_bytes, cudaMemcpyDeviceToHost), "cudaMemcpy canonical D2H") ||
        !cuda_call_ok(cudaMemcpy(h_r, d_out_r, tile_bytes, cudaMemcpyDeviceToHost), "cudaMemcpy row-major D2H")) {
        cleanup();
        return 1;
    }

    // Verify row-major: the TMA tile must be a bit-exact copy of K.
    int rm_mismatches = 0;
    for (int i = 0; i < SK * MMA_K; i++) {
        if (h_r[i] != h_k[i]) {
            if (rm_mismatches < 4) {
                printf("  row-major (%d,%d): got %f, expected %f\n", i / MMA_K, i % MMA_K,
                       bf16_to_f32_host(h_r[i]), bf16_to_f32_host(h_k[i]));
            }
            rm_mismatches++;
        }
    }
    printf("Row-major: %d mismatches out of %d\n", rm_mismatches, SK * MMA_K);

    // Verify canonical using the same (r, c) -> core-matrix index map as the kernel.
    int cn_mismatches = 0;
    for (int r = 0; r < SK; r++) {
        for (int c = 0; c < MMA_K; c++) {
            const int ck = c / 8, lc = c % 8, tmn = r / 8, lr = r % 8;
            const int canon_idx = ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc;
            if (h_k[r * MMA_K + c] != h_c[canon_idx]) {
                if (cn_mismatches < 4) {
                    printf("  canonical (%d,%d) -> [%d]: got %f, expected %f\n", r, c, canon_idx,
                           bf16_to_f32_host(h_c[canon_idx]), bf16_to_f32_host(h_k[r * MMA_K + c]));
                }
                cn_mismatches++;
            }
        }
    }
    printf("Canonical: %d mismatches out of %d\n", cn_mismatches, SK * MMA_K);

    const bool passed = (rm_mismatches + cn_mismatches == 0);
    printf("%s\n", passed ? "PASSED" : "FAILED");
    cleanup();
    return passed ? 0 : 1;
}