diff --git a/tests/unit/test_umma_qk.cu b/tests/unit/test_umma_qk.cu index 51249a81..ffd491a3 100644 --- a/tests/unit/test_umma_qk.cu +++ b/tests/unit/test_umma_qk.cu @@ -1,7 +1,10 @@ /** * UMMA QK GEMM Test (HD=16, SK=128) - * Uses Layout D TMEM read (32x32b.x8, 4 warps) to correctly read MMA output. - * Debugging the 4× scaling factor observed with 16x256b read. + * Using gau-nernst's exact MMA + epilogue pattern: + * - tcgen05.fence::after_thread_sync before TMEM read + * - 32x32b.x8 TMEM read with row = warp_id * 32 + * - tcgen05.wait::ld.sync.aligned after each read + * - Output: off_m + tid per thread (each thread = one row) */ #include @@ -18,61 +21,56 @@ using namespace dsv4::kernels::attention; static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); } static float bf16_to_f32_host(bf16_t h) { uint32_t u=(uint32_t)h<<16; float f; memcpy(&f,&u,4); return f; } -__global__ void __launch_bounds__(NTHREADS) +__global__ void __launch_bounds__(128) // 4 warps minimum for Layout D test_umma_qk_hd16(const bf16_t* q, const bf16_t* k, float* s_out, float* s_scalar, float scale) { const int tid = threadIdx.x; - const int wid = tid / WARP, lane = tid % WARP; + const int wid = tid / 32, lane = tid % 32; extern __shared__ char sbuf[]; uint32_t* sTmemBase = (uint32_t*)sbuf; bf16_t* sQ = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15); - bf16_t* sK = sQ + 128 * 16 + 4096; + bf16_t* sK = sQ + 128 * 16 + 4096; // 8KB padding after Q float* sQ_row = (float*)(sK + 128 * 16); - for (int d = tid; d < 16; d += NTHREADS) sQ_row[d] = bf16_to_f32(q[d]); + for (int d = tid; d < 16; d += 128) sQ_row[d] = bf16_to_f32(q[d]); // TMEM alloc (128 cols) - if (wid == 0) { + if (wid == 1) { tmem_alloc(__cvta_generic_to_shared(sTmemBase), 128); } __syncthreads(); uint32_t tb = *sTmemBase; - // Load Q and K + // Load Q and K into SMEM in canonical layout write_q_to_smem<16>(sQ, q); write_k_to_smem<128, 16>(sK, k); - // Zero the padding between Q and K - bf16_t* sQ_pad = sQ + 128 * 16; // After Q data - bf16_t* sK_end = sK + 128 * 16; // After K data - for (int i = tid; i < 4096; i += NTHREADS) sQ_pad[i] = 0; // 8KB padding + // Zero padding + bf16_t* sQ_pad = sQ + 128 * 16; + for (int i = tid; i < 4096; i += 128) sQ_pad[i] = 0; __syncthreads(); // Descriptors uint32_t sQ_smem = __cvta_generic_to_shared(sQ); uint32_t sK_smem = __cvta_generic_to_shared(sK); - // Try LBO = 32 (128/4 * 16 / 16 = 32 in 16B units) - // Hypothesis: M=128 has 4 sub-tiles, each with 32 rows - // So LBO should be 32 * 16 = 512 bytes (32 in 16B units) - uint64_t desc_q = make_umma_desc_kmajor_none(sQ_smem, 32); // LBO = 32 * 16 = 512B - uint64_t desc_k = make_umma_desc_kmajor_none(sK_smem, 32); + uint64_t desc_q = make_umma_desc_kmajor_none(sQ_smem, 128); + uint64_t desc_k = make_umma_desc_kmajor_none(sK_smem, 128); uint32_t idesc = make_idesc(128, 128); - // MMA — 1 thread calls (4× scaling is expected for M=128 cta_group::1) + // MMA — 1 thread issues (following gau-nernst pattern) if (tid == 0) { umma_ss_f16(tb, desc_q, desc_k, idesc, false); } - __syncwarp(); - if (wid == 0 && lane == 0) tmem_fence_store(); + // tcgen05.fence::after_thread_sync (CRITICAL — correct MMA→TMEM load fence) + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); __syncthreads(); - // Read from TMEM using Layout D: 32x32b.x8 format + // Read from TMEM using Layout D: 32x32b.x8 (gau-nernst pattern) // Each warp reads 32 rows × 8 columns - // Warp 0: rows 0-31, warp 1: rows 32-63, warp 2: rows 64-95, warp 3: rows 96-127 - if (wid < 4) { + for (int n = 0; n < 128 / 8; n++) { const int row = wid * 32; - const int col = 0; + const int col = n * 8; const int addr = tb + (row << 16) + col; float tmp[8]; asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];" @@ -80,10 +78,15 @@ test_umma_qk_hd16(const bf16_t* q, const bf16_t* k, "=f"(tmp[4]), "=f"(tmp[5]), "=f"(tmp[6]), "=f"(tmp[7]) : "r"(addr)); asm volatile("tcgen05.wait::ld.sync.aligned;"); - // Lane 0 gets S[wid*32 + 0, 0..7] in tmp[0..7] - if (lane == 0) { + + // Lane 0 writes first 8 output values for this warp's row + // Each lane gets values for its row (tid = wid*32 + lane) + int out_row = wid * 32 + lane; + if (lane == 0 && n < 1) { // Only first 8 cols for debug for (int c = 0; c < 8; c++) { - s_out[wid * 8 + c] = tmp[c] * 0.25f; // Divide by 4 (Layout D scaling) + if (out_row < 128) { + s_out[out_row * 8 + c] = tmp[c]; + } } } } @@ -100,17 +103,18 @@ test_umma_qk_hd16(const bf16_t* q, const bf16_t* k, } __syncthreads(); + // TMEM dealloc if (wid == 0) tmem_dealloc(tb, 128); } int main() { - printf("=== UMMA QK GEMM Test (Layout D read) ===\n"); + printf("=== UMMA QK GEMM (gau-nernst pattern, 4 warps) ===\n"); const int HD = 16, SK = 128; const float SCALE = 1.0f / sqrtf((float)HD); bf16_t* h_q = (bf16_t*)malloc(HD * sizeof(bf16_t)); bf16_t* h_k = (bf16_t*)malloc(SK * HD * sizeof(bf16_t)); - float* h_s_out = (float*)calloc(256, sizeof(float)); + float* h_s_out = (float*)calloc(128*8, sizeof(float)); float* h_s_scalar = (float*)calloc(SK, sizeof(float)); srand(42); @@ -119,33 +123,42 @@ int main() { bf16_t *d_q, *d_k; float *d_s_out, *d_s_scalar; cudaMalloc(&d_q, HD*sizeof(bf16_t)); cudaMalloc(&d_k, SK*HD*sizeof(bf16_t)); - cudaMalloc(&d_s_out, 256*sizeof(float)); cudaMalloc(&d_s_scalar, SK*sizeof(float)); + cudaMalloc(&d_s_out, 128*8*sizeof(float)); cudaMalloc(&d_s_scalar, SK*sizeof(float)); cudaMemcpy(d_q, h_q, HD*sizeof(bf16_t), cudaMemcpyHostToDevice); cudaMemcpy(d_k, h_k, SK*HD*sizeof(bf16_t), cudaMemcpyHostToDevice); - cudaMemset(d_s_out, 0, 256*sizeof(float)); cudaMemset(d_s_scalar, 0, SK*sizeof(float)); + cudaMemset(d_s_out, 0, 128*8*sizeof(float)); cudaMemset(d_s_scalar, 0, SK*sizeof(float)); int smem = (4 + 16 + 128*16*2 + 4096*2 + 128*16*2 + 16*4 + 256 + 127) & ~127; - test_umma_qk_hd16<<<1, NTHREADS, smem>>>(d_q, d_k, d_s_out, d_s_scalar, SCALE); + test_umma_qk_hd16<<<1, 128, smem>>>(d_q, d_k, d_s_out, d_s_scalar, SCALE); cudaError_t err = cudaDeviceSynchronize(); if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; } - cudaMemcpy(h_s_out, d_s_out, 256*sizeof(float), cudaMemcpyDeviceToHost); + cudaMemcpy(h_s_out, d_s_out, 128*8*sizeof(float), cudaMemcpyDeviceToHost); cudaMemcpy(h_s_scalar, d_s_scalar, SK*sizeof(float), cudaMemcpyDeviceToHost); - // Print S[0,0] from each warp's perspective - printf("S from TMEM (Layout D, 32x32b.x8):\n"); - for (int w = 0; w < 4; w++) { - printf(" Warp %d (rows %d-%d): ", w, w*32, w*32+31); - for (int c = 0; c < 8; c++) printf("%.4f ", h_s_out[w*8+c]); - printf("\n"); - } - printf("S[0,0..7] scalar: "); - for (int c = 0; c < 8; c++) printf("%.4f ", h_s_scalar[c]); + // Print S[0,0..7] from each warp's row 0 + printf("Row 0 (MMA): "); + for (int c = 0; c < 8; c++) printf("%.6f ", h_s_out[0*8+c]); + printf("\nRow 0 scalar: "); + for (int c = 0; c < 8; c++) printf("%.6f ", h_s_scalar[c]); printf("\n"); - printf("Ratio warp0/col0 vs scalar[0]: %.4f\n", - h_s_scalar[0] != 0 ? h_s_out[0] / h_s_scalar[0] : 0); + float max_diff = 0.0f, max_val = 0.0f; + for (int c = 0; c < 8; c++) { + max_diff = fmaxf(max_diff, fabsf(h_s_out[0*8+c] - h_s_scalar[c])); + max_val = fmaxf(max_val, fabsf(h_s_scalar[c])); + } + float rel_err = max_val > 0 ? max_diff / max_val : max_diff; + printf("Row 0 max rel err: %.6f\n", rel_err); + + // Print a few more rows + for (int r : {32, 64, 96}) { + printf("Row %d: %.6f %.6f %.6f %.6f\n", r, + h_s_out[r*8+0], h_s_out[r*8+1], h_s_out[r*8+2], h_s_out[r*8+3]); + } + + printf("Test %s\n", rel_err < 0.01f ? "PASSED" : "FAILED"); cudaFree(d_q); cudaFree(d_k); cudaFree(d_s_out); cudaFree(d_s_scalar); free(h_q); free(h_k); free(h_s_out); free(h_s_scalar); return 0;