From 9dfada66262554b4dca8080da6ca29c778daeb4b Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 29 May 2026 19:28:23 +0000 Subject: [PATCH] test: TMA + canonical + QK GEMM incremental --- tests/unit/test_tma_qk.cu | 247 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 247 insertions(+) create mode 100644 tests/unit/test_tma_qk.cu diff --git a/tests/unit/test_tma_qk.cu b/tests/unit/test_tma_qk.cu new file mode 100644 index 00000000..3a8244da --- /dev/null +++ b/tests/unit/test_tma_qk.cu @@ -0,0 +1,247 @@ +/** + * Incremental TMA + QK GEMM test. + * Step 1: TMA load K sub-tile (PROVEN WORKING) + * Step 2: Convert to canonical layout (ADD) + * Step 3: MMA QK (ADD) + * Step 4: Read S from TMEM (ADD) + * + * Test with HD=64, NUM_THREADS=128 (same as test_qk_softmax). + */ + +#include +#include +#include +#include +#include +#include + +#ifndef HD_VAL +#define HD_VAL 64 +#endif + +#ifndef NUM_THREADS +#define NUM_THREADS 128 +#endif + +typedef unsigned short bf16_t; +static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); } +static float bf16_to_f32_host(bf16_t h) { uint32_t u=(uint32_t)h<<16; float f; memcpy(&f,&u,4); return f; } + +constexpr int HD = HD_VAL; +constexpr int SK = 128; +constexpr int MMA_K = 16; +constexpr int BLOCK_MN = 128; +constexpr int TILE_SZ = 128 * MMA_K; // 2048 BF16 +constexpr int TMEM_N = (HD <= 128) ? 128 : 256; +constexpr int NKT = HD / MMA_K; +constexpr int CORES_MN = 16; +constexpr int NUM_READS = SK / 8; + +// ---- PTX helpers ---- +__device__ __forceinline__ void tma_mbarrier_init(uint32_t smem_mbar, uint32_t count) { + asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;" :: "r"(smem_mbar), "r"(count)); +} +__device__ __forceinline__ void tma_mbarrier_arrive_expect_tx(uint32_t smem_mbar, uint32_t tx_bytes) { + asm volatile("mbarrier.arrive.expect_tx.release.cta.shared::cta.b64 _, [%0], %1;" + :: "r"(smem_mbar), "r"(tx_bytes) : "memory"); +} +__device__ __forceinline__ void tma_load_2d(uint32_t dst, uint64_t desc, uint32_t mbar, int cx, int cy) { + asm volatile("cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes " + "[%0], [%1, {%3, %4}], [%2];" :: "r"(dst), "l"(desc), "r"(mbar), "r"(cx), "r"(cy) : "memory"); +} +__device__ __forceinline__ void tma_mbarrier_wait(uint32_t smem_mbar, int phase) { + asm volatile("{\n\t.reg .pred P1;\n\tLAB_WAIT:mbarrier.try_wait.parity.acquire.cta.shared::cta.b64 P1, [%0], %1, %2;\n\t@P1 bra.uni DONE;\n\tbra.uni LAB_WAIT;\n\tDONE:\n\t}" :: "r"(smem_mbar), "r"(phase), "r"(0x989680) : "memory"); +} + +__device__ __forceinline__ bf16_t f32_to_bf16(float f) { bf16_t h; asm("cvt.rn.bf16.f32 %0, %1;" : "=h"(h) : "f"(f)); return h; } +__device__ __forceinline__ float bf16_to_f32(bf16_t h) { float f; asm("cvt.f32.bf16 %0, %1;" : "=f"(f) : "h"(h)); return f; } +__device__ __forceinline__ float wmax(float v) { for(int o=16;o>0;o>>=1) v=fmaxf(v,__shfl_xor_sync(0xFFFFFFFF,v,o)); return v; } +__device__ __forceinline__ float wsum(float v) { for(int o=16;o>0;o>>=1) v+=__shfl_xor_sync(0xFFFFFFFF,v,o); return v; } + +__device__ void tmem_alloc(uint32_t smem_ptr, int n) { + asm volatile("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;" :: "r"(smem_ptr), "r"(n)); +} +__device__ void tmem_dealloc(uint32_t tmem_ptr, int n) { + asm volatile("tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;" :: "r"(tmem_ptr), "r"(n)); +} + +__device__ __forceinline__ uint64_t desc_encode(uint64_t byte_val) { return byte_val >> 4; } +__device__ __forceinline__ uint64_t make_umma_desc_kmajor_none(uint32_t smem_addr, int block_mn) { + uint64_t desc = 0; + desc |= desc_encode(smem_addr) & 0x3FFF; + desc |= (desc_encode(block_mn * 16) & 0x3FFF) << 16; + desc |= (desc_encode(128) & 0x3FFF) << 32; + desc |= 1ULL << 46; + return desc; +} +__device__ __forceinline__ uint32_t make_idesc(int bm, int bn) { + return (1U<<4)|(1U<<7)|(1U<<10)|((uint32_t)(bn>>3)<<17)|((uint32_t)(bm>>4)<<24); +} +__device__ void umma_ss_f16(uint32_t tc, uint64_t da, uint64_t db, uint32_t idesc, bool acc) { + uint32_t sb = acc ? 0x3F800000u : 0u; + asm volatile("{\n\t.reg .pred p;\n\tsetp.ne.b32 p, %4, 0;\n\ttcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8}, p;\n\t}" + :: "r"(tc), "l"(da), "l"(db), "r"(idesc), "r"(sb), "r"(0), "r"(0), "r"(0), "r"(0)); +} + +__global__ void __launch_bounds__(NUM_THREADS) +test_tma_qk_kernel( + float* __restrict__ out_s, + const bf16_t* __restrict__ q, + CUtensorMap* __restrict__ tma_k, + int s_k +) { + const int tid = threadIdx.x; + const int wid = tid / 32; + const int lane = tid % 32; + + extern __shared__ __align__(128) char sbuf[]; + size_t off = 0; + uint32_t* sTmemBase = (uint32_t*)(sbuf + off); off = 4; + off = (off + 127) & ~(size_t)127; + bf16_t* sQ0 = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t); + bf16_t* sK0 = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t); + bf16_t* sTmaBuf = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t); + off = (off + 15) & ~(size_t)15; + uint64_t* sMbar = (uint64_t*)(sbuf + off); off += 8; + + // Init + if (wid == 1) tmem_alloc(__cvta_generic_to_shared(sTmemBase), TMEM_N); + if (tid == 0) { + tma_mbarrier_init((uint32_t)__cvta_generic_to_shared(sMbar), 1); + asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory"); + } + __syncthreads(); + uint32_t tb = *sTmemBase; + const uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar); + int phase = 0; + + // QK GEMM + { + uint32_t idesc = make_idesc(BLOCK_MN, BLOCK_MN); + for (int kt = 0; kt < NKT; kt++) { + // Load Q sub-tile directly + for (int i = tid; i < TILE_SZ; i += NUM_THREADS) sQ0[i] = 0; + for (int d = tid; d < MMA_K; d += NUM_THREADS) { + int full_d = kt * MMA_K + d; + if (full_d < HD) { + int ck = d / 8, lc = d % 8; + sQ0[ck * CORES_MN * 64 + lc] = q[full_d]; + } + } + __syncthreads(); + + // Load K sub-tile via TMA + if (wid == 0 && lane == 0) { + tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)tma_k, mbar_addr, kt * MMA_K, 0); + tma_mbarrier_arrive_expect_tx(mbar_addr, TILE_SZ * sizeof(bf16_t)); + } + tma_mbarrier_wait(mbar_addr, phase); phase ^= 1; + __syncthreads(); + + // Convert row-major sTmaBuf → canonical sK0 + for (int i = tid; i < TILE_SZ; i += NUM_THREADS) sK0[i] = 0; + for (int i = tid; i < s_k * MMA_K; i += NUM_THREADS) { + int r = i / MMA_K, c = i % MMA_K; + int ck = c / 8, lc = c % 8, tmn = r / 8, lr = r % 8; + sK0[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = sTmaBuf[i]; + } + __syncthreads(); + + // MMA + if (tid == 0) { + uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), BLOCK_MN); + uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK0), BLOCK_MN); + umma_ss_f16(tb, dq, dk, idesc, kt > 0); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + } + __syncthreads(); + } + } + + asm volatile("fence.sc.gpu;" ::: "memory"); + __syncthreads(); + + // Read S from TMEM + if (wid == 0) { + for (int n = 0; n < NUM_READS; n++) { + float tmp[8]; + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" + : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), + "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) + : "r"(tb + n * 8)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + if (lane == 0) { + for (int c = 0; c < 8; c++) { + int col = n * 8 + c; + if (col < s_k) out_s[col] = tmp[c]; + } + } + } + } + __syncthreads(); + if (wid == 0) tmem_dealloc(tb, TMEM_N); +} + +inline bool create_tma_desc_2d_bf16( + CUtensorMap* out, const void* ptr, + uint64_t rows, uint64_t cols, + uint32_t tile_rows, uint32_t tile_cols +) { + uint64_t gd[] = {cols, rows}, gs[] = {cols * 2}; + uint32_t td[] = {tile_cols, tile_rows}, ts[] = {1, 1}; + CUresult r = cuTensorMapEncodeTiled(out, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 2, + const_cast(ptr), gd, gs, td, ts, + CU_TENSOR_MAP_INTERLEAVE_NONE, CU_TENSOR_MAP_SWIZZLE_NONE, + CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + if (r != CUDA_SUCCESS) { fprintf(stderr, "TMA desc failed: %d\n", (int)r); return false; } + int dv=0; cudaDriverGetVersion(&dv); + if (dv <= 13010 && rows*cols*2 < 131072) reinterpret_cast(out)[1] &= ~(1ULL<<21); + return true; +} + +int main() { + printf("=== TMA + QK GEMM (HD=%d, NUM_THREADS=%d) ===\n", HD, NUM_THREADS); + + bf16_t* h_q = (bf16_t*)calloc(HD, sizeof(bf16_t)); + bf16_t* h_k = (bf16_t*)calloc(SK * HD, sizeof(bf16_t)); + srand(42); + for (int i = 0; i < HD; i++) h_q[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f); + for (int i = 0; i < SK * HD; i++) h_k[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f); + + bf16_t *d_q, *d_k; float *d_out; + cudaMalloc(&d_q, HD * sizeof(bf16_t)); + cudaMalloc(&d_k, SK * HD * sizeof(bf16_t)); + cudaMalloc(&d_out, SK * sizeof(float)); + cudaMemcpy(d_q, h_q, HD * sizeof(bf16_t), cudaMemcpyHostToDevice); + cudaMemcpy(d_k, h_k, SK * HD * sizeof(bf16_t), cudaMemcpyHostToDevice); + + CUtensorMap tma_k; CUtensorMap* d_tma_k; + create_tma_desc_2d_bf16(&tma_k, d_k, SK, HD, 128, MMA_K); + cudaMalloc(&d_tma_k, sizeof(CUtensorMap)); + cudaMemcpy(d_tma_k, &tma_k, sizeof(CUtensorMap), cudaMemcpyHostToDevice); + + size_t smem = 4 + 128 + TILE_SZ*3 + 16 + 8 + 256; + cudaFuncSetAttribute(test_tma_qk_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, (int)smem); + test_tma_qk_kernel<<<1, NUM_THREADS, smem>>>(d_out, d_q, d_tma_k, SK); + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; } + + float* h_out = (float*)malloc(SK * sizeof(float)); + cudaMemcpy(h_out, d_out, SK * sizeof(float), cudaMemcpyDeviceToHost); + + // Reference: raw dot product (MMA is unscaled) + int fail = 0; float max_rel = 0; + for (int j = 0; j < SK; j++) { + float dot = 0; + for (int d = 0; d < HD; d++) dot += bf16_to_f32_host(h_q[d]) * bf16_to_f32_host(h_k[j*HD+d]); + float ref = dot; + float got = h_out[j]; + float rel = fabsf(ref) > 1e-4f ? fabsf(got-ref)/fabsf(ref) : fabsf(got-ref); + if (rel > max_rel) max_rel = rel; + if (rel > 0.01f && fail < 5) printf(" j=%d: ref=%.6f got=%.6f\n", j, ref, got); + if (rel > 0.01f) fail++; + } + printf("Max rel err: %.8f, failures: %d\n", max_rel, fail); + printf("%s\n", fail == 0 ? "PASSED" : "FAILED"); + return fail == 0 ? 0 : 1; +}