/**
 * UMMA QK GEMM Test — HD=64 (4 K-tiles), SK=128
 *
 * File: tests/unit/test_umma_qk_hd64.cu
 *
 * Computes S = scale * (Q · K^T) with one UMMA per 16-column K-tile:
 * the first tile overwrites the accumulator (accumulate=false), the
 * remaining tiles add to it (accumulate=true). Each K-tile gets its own
 * shared-memory descriptor. Row 0 of the result is compared against a
 * scalar dot-product reference computed in the same kernel.
 *
 * The tcgen05 instructions used below are only available on GPUs that
 * implement them; the test cannot run elsewhere.
 */

// NOTE(review): the header names of the five original standard #include lines
// were lost when this file was extracted; the list below is reconstructed from
// the symbols the code uses. Confirm against the repository copy.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>

#include <cuda_runtime.h>

#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"

using namespace dsv4::kernels::attention;

// Host float -> bf16: keeps the upper 16 bits of the IEEE-754 pattern
// (truncation, no rounding).
static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u, &f, 4); return (uint16_t)(u >> 16); }
// Host bf16 -> float: places the 16 stored bits in the upper half of a float.
static float bf16_to_f32_host(bf16_t h) { uint32_t u = (uint32_t)h << 16; float f; memcpy(&f, &u, 4); return f; }

// MMA K-tile = 16 BF16
constexpr int MMA_K = 16;
constexpr int N_WARPS = 4;
constexpr int N_THREADS = N_WARPS * 32;
constexpr int BLOCK_M = 128;
// N extent of the MMA and number of TMEM columns allocated for the accumulator.
constexpr int BLOCK_N = 128;
// Number of leading result columns copied to s_out (s_out is BLOCK_M x OUT_COLS).
constexpr int OUT_COLS = 16;
// Byte distance between adjacent 8-column core-matrix columns in the layout
// written below: 16 row-groups * 64 elements * 2 bytes.
constexpr int CORE_COL_BYTES = 2048;
// One K-tile (16 columns) is two core-matrix columns.
constexpr int KTILE_BYTES = 2 * CORE_COL_BYTES;

// Evaluates a CUDA runtime call; on failure prints the error and exits with 1.
#define CUDA_CHECK(call)                                                   \
    do {                                                                   \
        cudaError_t cuda_check_err_ = (call);                              \
        if (cuda_check_err_ != cudaSuccess) {                              \
            printf("CUDA ERROR: %s\n", cudaGetErrorString(cuda_check_err_)); \
            printf("  at %s:%d: %s\n", __FILE__, __LINE__, #call);         \
            exit(1);                                                       \
        }                                                                  \
    } while (0)

/**
 * Single-block UMMA test kernel.
 *
 * Launch: one block of N_WARPS * 32 = 128 threads; dynamic shared memory must
 * hold a 4-byte TMEM base slot, alignment padding, and two 128 x hd bf16 tiles.
 *
 * Preconditions (not checked on device):
 *   - hd is a multiple of MMA_K (n_ktiles = hd / MMA_K truncates otherwise);
 *   - sk <= 128 (the shared-memory layout provides 16 groups of 8 rows).
 *
 * @param q        hd bf16 values: the single query row
 * @param k        sk * hd bf16 values, row-major (sk, hd)
 * @param s_out    BLOCK_M x OUT_COLS floats: first OUT_COLS accumulator
 *                 columns of every row, multiplied by scale
 * @param s_scalar sk floats: scalar reference scale * dot(q, k[c, :])
 * @param scale    factor applied to both outputs
 * @param hd       head dimension (length of each dot product)
 * @param sk       number of K rows
 */
__global__ void __launch_bounds__(N_WARPS * 32)
test_umma_qk_hd64(const bf16_t* q, const bf16_t* k,
                  float* s_out, float* s_scalar, float scale, int hd, int sk)
{
    const int tid = threadIdx.x;
    const int wid = tid / 32, lane = tid % 32;
    const int n_ktiles = hd / MMA_K;  // 4 for hd=64

    // SMEM: [TMEM base slot (4 B)] [pad to 16 B] [sQ (128, hd)] [sK (128, hd)]
    // Each (128, 16) K-tile occupies KTILE_BYTES = 4096 bytes.
    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;
    bf16_t* sQ = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
    bf16_t* sK = sQ + BLOCK_M * hd;

    // TMEM alloc (128 cols for N=128); the base address is published via SMEM.
    if (wid == 1) {
        tmem_alloc(__cvta_generic_to_shared(sTmemBase), BLOCK_N);
    }
    __syncthreads();
    uint32_t tb = *sTmemBase;

    // Zero both tiles first: only row 0 of sQ and the first sk rows of sK are
    // filled afterwards, the rest must contribute nothing to the product.
    for (int i = tid; i < BLOCK_M * hd; i += N_THREADS) {
        sQ[i] = 0;
        sK[i] = 0;
    }
    __syncthreads();

    // Shared-memory layout used for both operands (element index, bf16 units):
    //   idx(r, c) = (c / 8) * 16 * 64 + (r / 8) * 64 + (r % 8) * 8 + (c % 8)
    // i.e. 8x8 core matrices of 64 elements; 16 row-groups are contiguous
    // inside one 8-column block, and 8-column blocks are 16 * 64 elements apart.

    // Write Q (1, hd) to row 0 (row-group 0, local row 0).
    for (int d = tid; d < hd; d += N_THREADS) {
        int core_k = d / 8, local_c = d % 8;
        int idx = core_k * 16 * 64 + local_c;
        sQ[idx] = q[d];
    }

    // Write K (sk, hd) to sK in the same layout.
    for (int i = tid; i < sk * hd; i += N_THREADS) {
        int r = i / hd, c = i % hd;
        int tile_mn = r / 8, core_k = c / 8;
        int local_r = r % 8, local_c = c % 8;
        int idx = core_k * 16 * 64 + tile_mn * 64 + local_r * 8 + local_c;
        sK[idx] = k[i];
    }
    // NOTE(review): sQ/sK are written with ordinary stores and then read by the
    // MMA. If make_umma_desc_kmajor_none / umma_ss_f16 do not already issue the
    // proxy fence the hardware requires for that hand-off, one is needed here —
    // confirm in fmha_umma_desc.cuh.
    __syncthreads();

    uint32_t sQ_smem = __cvta_generic_to_shared(sQ);
    uint32_t sK_smem = __cvta_generic_to_shared(sK);

    // Per-K-tile descriptors.
    // K-tile kt covers columns [16*kt, 16*kt + 16), i.e. the 8-column blocks
    // 2*kt and 2*kt + 1. From the index formula above, 8-column blocks are
    // 16 * 64 elements = 2048 bytes apart and 8-row groups are 64 elements =
    // 128 bytes apart, so tile kt starts KTILE_BYTES * kt bytes after the tile
    // base and has the same internal strides as tile 0. Only the start address
    // differs between tiles.
    // NOTE(review): the original notes call these strides LBO=2048 and SBO=128
    // and expect make_umma_desc_kmajor_none to encode them — confirm against
    // the helper, which is not visible in this file.
    uint32_t idesc = make_idesc(BLOCK_M, BLOCK_N);

    for (int kt = 0; kt < n_ktiles; kt++) {
        uint32_t q_ktile_addr = sQ_smem + kt * KTILE_BYTES;
        uint32_t k_ktile_addr = sK_smem + kt * KTILE_BYTES;
        uint64_t desc_q = make_umma_desc_kmajor_none(q_ktile_addr, BLOCK_M);
        uint64_t desc_k = make_umma_desc_kmajor_none(k_ktile_addr, BLOCK_M);

        // First tile overwrites the accumulator, later tiles add to it.
        bool accumulate = (kt > 0);

        // Exactly one thread issues the MMA for the whole block. Issuing it
        // from every warp leader would add each accumulating tile once per
        // issuing warp.
        if (tid == 0) {
            umma_ss_f16(tb, desc_q, desc_k, idesc, accumulate);
        }
        // NOTE(review): if umma_ss_f16 does not itself wait for the MMA to
        // finish, completion must be tracked (commit + mbarrier wait) before
        // the accumulator is read below — confirm in fmha_umma_desc.cuh.
        asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
        __syncthreads();
    }

    // Read from TMEM: 32x32b.x8, each warp reads its 32 rows x 8 columns per
    // step. Only the first OUT_COLS columns are consumed by the host, so only
    // those are loaded.
    for (int n = 0; n < OUT_COLS / 8; n++) {
        const int row = wid * 32;
        const int col = n * 8;
        const int addr = tb + (row << 16) + col;
        float tmp[8];
        asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];"
                     : "=f"(tmp[0]), "=f"(tmp[1]), "=f"(tmp[2]), "=f"(tmp[3]),
                       "=f"(tmp[4]), "=f"(tmp[5]), "=f"(tmp[6]), "=f"(tmp[7])
                     : "r"(addr));
        asm volatile("tcgen05.wait::ld.sync.aligned;");

        int out_row = wid * 32 + lane;
        if (out_row < BLOCK_M) {
            for (int c = 0; c < 8; c++) {
                s_out[out_row * OUT_COLS + n * 8 + c] = tmp[c] * scale;
            }
        }
    }
    // All warps must be done reading TMEM before it is released below.
    __syncthreads();

    // Scalar reference: one output column per thread, same summation order
    // over d as a sequential loop.
    for (int c = tid; c < sk; c += N_THREADS) {
        float dot = 0.0f;
        for (int d = 0; d < hd; d++)
            dot += bf16_to_f32(q[d]) * bf16_to_f32(k[c * hd + d]);
        s_scalar[c] = dot * scale;
    }

    if (wid == 0) tmem_dealloc(tb, BLOCK_N);
}

/**
 * Host driver: fills Q and K with seeded pseudo-random bf16 values, runs the
 * kernel once, and compares the first 16 columns of result row 0 with the
 * scalar reference.
 *
 * @return 0 when the relative error is below 1%, 1 on mismatch, non-finite
 *         results, allocation failure, or any CUDA error.
 */
int main() {
    printf("=== UMMA QK GEMM Test (HD=64, 4 K-tiles) ===\n");
    const int HD = 64, SK = 128;
    static_assert(HD % MMA_K == 0, "HD must be a whole number of K-tiles");
    static_assert(SK <= BLOCK_M, "the shared-memory K tile holds at most BLOCK_M rows");
    const float SCALE = 1.0f / sqrtf((float)HD);

    bf16_t* h_q = (bf16_t*)malloc(HD * sizeof(bf16_t));
    bf16_t* h_k = (bf16_t*)malloc(SK * HD * sizeof(bf16_t));
    float* h_s_out = (float*)calloc(BLOCK_M * OUT_COLS, sizeof(float));
    float* h_s_scalar = (float*)calloc(SK, sizeof(float));
    if (!h_q || !h_k || !h_s_out || !h_s_scalar) {
        printf("Host allocation failed\n");
        free(h_q); free(h_k); free(h_s_out); free(h_s_scalar);
        return 1;
    }

    // Fixed seed: inputs in [-0.5, 0.49], identical on every run.
    srand(42);
    for (int d = 0; d < HD; d++) h_q[d] = f32_to_bf16_host((float)(rand() % 100) / 100.0f - 0.5f);
    for (int i = 0; i < SK * HD; i++) h_k[i] = f32_to_bf16_host((float)(rand() % 100) / 100.0f - 0.5f);

    bf16_t *d_q = nullptr, *d_k = nullptr;
    float *d_s_out = nullptr, *d_s_scalar = nullptr;
    CUDA_CHECK(cudaMalloc(&d_q, HD * sizeof(bf16_t)));
    CUDA_CHECK(cudaMalloc(&d_k, SK * HD * sizeof(bf16_t)));
    CUDA_CHECK(cudaMalloc(&d_s_out, BLOCK_M * OUT_COLS * sizeof(float)));
    CUDA_CHECK(cudaMalloc(&d_s_scalar, SK * sizeof(float)));
    CUDA_CHECK(cudaMemcpy(d_q, h_q, HD * sizeof(bf16_t), cudaMemcpyHostToDevice));
    CUDA_CHECK(cudaMemcpy(d_k, h_k, SK * HD * sizeof(bf16_t), cudaMemcpyHostToDevice));
    CUDA_CHECK(cudaMemset(d_s_out, 0, BLOCK_M * OUT_COLS * sizeof(float)));
    CUDA_CHECK(cudaMemset(d_s_scalar, 0, SK * sizeof(float)));

    // SMEM: sTmemBase(4) + pad(16) + sQ(128*64*2=16384) + sK(128*64*2=16384) + pad,
    // rounded up to a multiple of 128 bytes.
    int smem = (4 + 16 + 128 * HD * 2 + 128 * HD * 2 + 256 + 127) & ~127;
    printf("SMEM: %d bytes\n", smem);

    test_umma_qk_hd64<<<1, N_WARPS * 32, smem>>>(d_q, d_k, d_s_out, d_s_scalar, SCALE, HD, SK);
    CUDA_CHECK(cudaGetLastError());       // launch-configuration errors
    CUDA_CHECK(cudaDeviceSynchronize());  // errors raised while the kernel ran

    CUDA_CHECK(cudaMemcpy(h_s_out, d_s_out, BLOCK_M * OUT_COLS * sizeof(float), cudaMemcpyDeviceToHost));
    CUDA_CHECK(cudaMemcpy(h_s_scalar, d_s_scalar, SK * sizeof(float), cudaMemcpyDeviceToHost));

    // Compare row 0
    printf("S[0,0..7] MMA: ");
    for (int c = 0; c < 8; c++) printf("%.6f ", h_s_out[0 * OUT_COLS + c]);
    printf("\nS[0,0..7] ref: ");
    for (int c = 0; c < 8; c++) printf("%.6f ", h_s_scalar[c]);
    printf("\n");

    // fmaxf returns the non-NaN operand, so a NaN result would otherwise leave
    // max_diff untouched and be reported as a pass; track finiteness explicitly.
    bool all_finite = true;
    float max_diff = 0.0f, max_val = 0.0f;
    for (int c = 0; c < OUT_COLS; c++) {
        const float got = h_s_out[0 * OUT_COLS + c];
        const float ref = h_s_scalar[c];
        if (!std::isfinite(got) || !std::isfinite(ref)) all_finite = false;
        max_diff = fmaxf(max_diff, fabsf(got - ref));
        max_val = fmaxf(max_val, fabsf(ref));
    }
    float rel_err = max_val > 0 ? max_diff / max_val : max_diff;
    if (!all_finite) rel_err = NAN;
    const bool passed = all_finite && rel_err < 0.01f;
    printf("Row 0 rel err (16 cols): %.6f\n", rel_err);
    printf("Test %s\n", passed ? "PASSED" : "FAILED");

    CUDA_CHECK(cudaFree(d_q));
    CUDA_CHECK(cudaFree(d_k));
    CUDA_CHECK(cudaFree(d_s_out));
    CUDA_CHECK(cudaFree(d_s_scalar));
    free(h_q); free(h_k); free(h_s_out); free(h_s_scalar);
    // The exit status carries the verdict so test runners can detect failure.
    return passed ? 0 : 1;
}