diff --git a/tests/unit/test_tma_kload.cu b/tests/unit/test_tma_kload.cu new file mode 100644 index 00000000..ca8b089a --- /dev/null +++ b/tests/unit/test_tma_kload.cu @@ -0,0 +1,133 @@ +/** + * Absolute minimum TMA + QK test: + * 1. Load K via TMA into sTmaBuf + * 2. Convert to canonical sK0 + * 3. Write sK0 to GMEM for verification + * No TMEM, no MMA. + */ + +#include +#include +#include +#include +#include + +#ifndef HD_VAL +#define HD_VAL 64 +#endif + +#include "dsv4/kernels/attention/fmha_common.cuh" +#include "dsv4/kernels/attention/fmha_umma_desc.cuh" +#include "dsv4/kernels/attention/fmha_tma.cuh" + +using namespace dsv4::kernels::attention; + +static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); } +static float bf16_to_f32_host(bf16_t h) { uint32_t u=(uint32_t)h<<16; float f; memcpy(&f,&u,4); return f; } + +constexpr int HD = HD_VAL; +constexpr int SK = 128; +constexpr int MMA_K = MMA_K_BF16; +constexpr int CORES_MN = 16; +constexpr int TILE_SZ = 128 * MMA_K; + +__global__ void __launch_bounds__(128) +test_tma_kload_kernel( + bf16_t* __restrict__ out_canonical, // (128, 16) canonical + bf16_t* __restrict__ out_rowmajor, // (128, 16) row-major + CUtensorMap* __restrict__ tma_k +) { + const int tid = threadIdx.x; + const int wid = tid / 32; + const int lane = tid % 32; + + // Simple SMEM: sTmaBuf (128,16) + sCanonical (128,16) + sMbar + extern __shared__ __align__(128) char sbuf[]; + bf16_t* sTmaBuf = (bf16_t*)(sbuf); + bf16_t* sCanonical = (bf16_t*)(sbuf + TILE_SZ * sizeof(bf16_t)); + uint64_t* sMbar = (uint64_t*)(sbuf + 2 * TILE_SZ * sizeof(bf16_t)); + + const uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar); + + if (tid == 0) { + tma_mbarrier_init(mbar_addr, 1); + asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory"); + } + __syncthreads(); + + // TMA load: (128, 16) tile at coord {0, 0} + if (wid == 0 && lane == 0) { + uint32_t smem_dst = (uint32_t)__cvta_generic_to_shared(sTmaBuf); + tma_load_2d(smem_dst, (uint64_t)tma_k, mbar_addr, 0, 0); + tma_mbarrier_arrive_expect_tx(mbar_addr, TILE_SZ * sizeof(bf16_t)); + } + tma_mbarrier_wait(mbar_addr, 0); + __syncthreads(); + + // Convert to canonical + for (int i = tid; i < TILE_SZ; i += 128) sCanonical[i] = 0; + for (int i = tid; i < SK * MMA_K; i += 128) { + int r = i / MMA_K, c = i % MMA_K; + int ck = c / 8, lc = c % 8, tmn = r / 8, lr = r % 8; + sCanonical[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = sTmaBuf[i]; + } + __syncthreads(); + + // Write canonical to GMEM + for (int i = tid; i < TILE_SZ; i += 128) out_canonical[i] = sCanonical[i]; + // Write row-major to GMEM + for (int i = tid; i < TILE_SZ; i += 128) out_rowmajor[i] = sTmaBuf[i]; +} + +int main() { + printf("TMA K-load only (HD=%d)\n", HD); + + bf16_t* h_k = (bf16_t*)calloc(SK * MMA_K, sizeof(bf16_t)); + srand(42); + for (int i = 0; i < SK * MMA_K; i++) h_k[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f); + + bf16_t *d_k, *d_out_c, *d_out_r; + cudaMalloc(&d_k, SK * MMA_K * sizeof(bf16_t)); + cudaMalloc(&d_out_c, TILE_SZ * sizeof(bf16_t)); + cudaMalloc(&d_out_r, TILE_SZ * sizeof(bf16_t)); + cudaMemcpy(d_k, h_k, SK * MMA_K * sizeof(bf16_t), cudaMemcpyHostToDevice); + + CUtensorMap tma_k; + CUtensorMap* d_tma_k; + if (!create_tma_desc_2d_bf16(&tma_k, d_k, SK, (uint64_t)MMA_K, 128, 16)) { + printf("TMA desc FAILED\n"); return 1; + } + cudaMalloc(&d_tma_k, sizeof(CUtensorMap)); + cudaMemcpy(d_tma_k, &tma_k, sizeof(CUtensorMap), cudaMemcpyHostToDevice); + + int smem = TILE_SZ * 2 * sizeof(bf16_t) + 16; + test_tma_kload_kernel<<<1, 128, smem>>>(d_out_c, d_out_r, d_tma_k); + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; } + + bf16_t* h_c = (bf16_t*)malloc(TILE_SZ * sizeof(bf16_t)); + bf16_t* h_r = (bf16_t*)malloc(TILE_SZ * sizeof(bf16_t)); + cudaMemcpy(h_c, d_out_c, TILE_SZ * sizeof(bf16_t), cudaMemcpyDeviceToHost); + cudaMemcpy(h_r, d_out_r, TILE_SZ * sizeof(bf16_t), cudaMemcpyDeviceToHost); + + // Verify row-major + int rm_mismatches = 0; + for (int i = 0; i < SK * MMA_K; i++) { + if (h_r[i] != h_k[i]) rm_mismatches++; + } + printf("Row-major: %d mismatches out of %d\n", rm_mismatches, SK * MMA_K); + + // Verify canonical + int cn_mismatches = 0; + for (int r = 0; r < SK; r++) { + for (int c = 0; c < MMA_K; c++) { + int ck = c/8, lc = c%8, tmn = r/8, lr = r%8; + int canon_idx = ck*CORES_MN*64 + tmn*64 + lr*8 + lc; + if (h_k[r*MMA_K + c] != h_c[canon_idx]) cn_mismatches++; + } + } + printf("Canonical: %d mismatches out of %d\n", cn_mismatches, SK * MMA_K); + + printf("%s\n", (rm_mismatches + cn_mismatches == 0) ? "PASSED" : "FAILED"); + return (rm_mismatches + cn_mismatches == 0) ? 0 : 1; +}