From 9b458d2a6cf99e5a509e59fba81c03dfa24410de Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 09:16:37 +0000 Subject: [PATCH] test_umma_qk: clean rewrite, hardcoded HD=16, explicit core-matrix layout writes --- tests/unit/test_umma_qk.cu | 250 ++++++++++++------------------------- 1 file changed, 83 insertions(+), 167 deletions(-) diff --git a/tests/unit/test_umma_qk.cu b/tests/unit/test_umma_qk.cu index 1a5f346f..31877d36 100644 --- a/tests/unit/test_umma_qk.cu +++ b/tests/unit/test_umma_qk.cu @@ -4,15 +4,7 @@ * Tests that tcgen05.mma produces correct QK attention scores. * Uses K-major, SWIZZLE_NONE, BF16 descriptors with core-matrix SMEM layout. * - * Strategy: - * 1. Create Q (1, HD) and K (SK, HD) on CPU, copy to GPU - * 2. Load Q into SMEM (padded to 128 rows) in core-matrix layout - * 3. Load K into SMEM (padded to 128 rows) in core-matrix layout - * 4. For each K-tile (16 BF16 columns), call tcgen05.mma SS - * 5. Read S from TMEM, compare against CPU reference - * * First test: HD=16, SK=128 (single K-tile, single MMA call) - * Second test: HD=64, SK=128 (4 K-tiles, accumulate) */ #include @@ -31,200 +23,131 @@ using namespace dsv4::kernels::attention; static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u, &f, 4); - uint16_t h = (u >> 16) & 0xFFFF; - return h; + return (uint16_t)(u >> 16); } // ================================================================== // Test kernel: UMMA QK GEMM for HD=16, SK=128 (single K-tile) // ================================================================== -// This is the simplest possible test: one MMA call, no K-tiling. 
-// Q: (1, 16) padded to (128, 16) in SMEM -// K: (128, 16) in SMEM (transposed view via K-major descriptor) -// S = Q @ K^T: (128, 128) in TMEM → we only care about row 0 - __global__ void __launch_bounds__(NTHREADS) test_umma_qk_hd16( const bf16_t* __restrict__ q, const bf16_t* __restrict__ k, - float* __restrict__ s_out, // output: S[0, 0..127] (row 0 of QK) - float* __restrict__ s_scalar, // scalar reference: Q @ K^T computed in SMEM - int sk, float scale + float* __restrict__ s_out, // output: S[0, 0..127] from TMEM + float* __restrict__ s_scalar, // scalar reference: Q @ K^T in SMEM + float scale ) { const int tid = threadIdx.x; const int wid = tid / WARP, lane = tid % WARP; - constexpr int HD = 16; - constexpr int SK = 128; - constexpr int KT = MMA_K_TILE; // 16 - constexpr int N_KTILES = HD / KT; // 1 for HD=16 - // ================================================================ // SMEM layout // ================================================================ // sQ_ktile: (128, 16) BF16 in K-major core-matrix layout = 4 KB // sK_ktile: (128, 16) BF16 in K-major core-matrix layout = 4 KB // sTmemBase: 4 bytes - // sQ_row: (HD) BF16 for scalar reference = 32 bytes - // sK_row: (SK*HD) BF16 for scalar reference = 4 KB (we read from GMEM) - // Total: ~8 KB + overhead, well within 232 KB + // sQ_row: (16) floats for scalar reference extern __shared__ char sbuf[]; uint32_t* sTmemBase = (uint32_t*)sbuf; - - // Align sQ_ktile to 16 bytes (required for UMMA descriptor) bf16_t* sQ_ktile = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15); - bf16_t* sK_ktile = sQ_ktile + 128 * KT; // After Q (128 * 16 BF16) - float* sQ_row = (float*)(sK_ktile + 128 * KT); // After K + bf16_t* sK_ktile = sQ_ktile + 128 * 16; + float* sQ_row = (float*)(sK_ktile + 128 * 16); // Load Q to sQ_row (float for scalar reference) - for (int d = tid; d < HD; d += NTHREADS) { + for (int d = tid; d < 16; d += NTHREADS) { sQ_row[d] = bf16_to_f32(q[d]); } // 
================================================================ - // TMEM allocation + // TMEM allocation: 128 columns for S (128, 128) // ================================================================ - // QK result S: (128, 128) FP32 in TMEM = 128 columns - // (We only need 128 columns for the full 128×128 result) - constexpr int TMEM_COLS = 128; - if (wid == 0) { uint32_t smem_ptr = __cvta_generic_to_shared(sTmemBase); - tmem_alloc(smem_ptr, TMEM_COLS); + tmem_alloc(smem_ptr, 128); } __syncthreads(); uint32_t tmem_base = *sTmemBase; - uint32_t tmem_s = tmem_base; // Zero TMEM S accumulator if (wid == 0) { - for (int col = 0; col < TMEM_COLS; col++) { - tmem_store(tmem_s + col, 0, 0, 0, 0); + for (int col = 0; col < 128; col++) { + tmem_store(tmem_base + col, 0, 0, 0, 0); } tmem_fence_store(); } __syncthreads(); // ================================================================ - // K-tile loop (just 1 tile for HD=16) + // Load Q and K into SMEM in core-matrix layout // ================================================================ - for (int kt = 0; kt < N_KTILES; kt++) { - // Zero SMEM K-tile buffers - zero_smem<128, KT>(sQ_ktile); - zero_smem<128, KT>(sK_ktile); - __syncthreads(); - - // Write Q K-tile to SMEM in core-matrix layout - // Q is (1, HD), padded to (128, 16). Row 0 has the actual data. - // Only write row 0 (the actual query vector). - // In core-matrix layout, row 0 is in core (0, 0) and core (0, 1). - // Core (0, c) starts at offset c * 64, row 0 within it is at offset 0. - // Row 0, column c*8 + j is at core (0, c), position 0*8 + j = j. - // So core (0, 0) = sQ_ktile[0..7], core (0, 1) = sQ_ktile[64..71] - for (int d = tid; d < KT; d += NTHREADS) { - int c = d; // column within this K-tile - int core_k = c / 8; - int local_c = c % 8; - int dst_idx = core_k * 64 + local_c; // tile_mn=0, so (0*CORES_K + core_k)*64 + 0*8 + local_c - sQ_ktile[dst_idx] = q[kt * KT + c]; - } - // Write K K-tile to SMEM in core-matrix layout - // K is (SK, HD). 
We write rows 0..SK-1. - write_smem_ktile<128>(sK_ktile, k, kt, HD); - // ^ k + kt*KT is wrong — k is (SK, HD) row-major. - // We need to extract columns [kt*16, kt*16+16) from K. - // Let me fix: for each row r of K, the source is k[r * HD + kt * KT + c] - // But write_smem_ktile expects (ROWS, hd) row-major with k_tile and hd. - // Actually write_smem_ktile(src, k_tile, hd) reads src[r * hd + k_tile*KT + c]. - // But K in GMEM is k[r * HD + c]. We need k[r * HD + kt*KT + c]. - // So pass k as the base pointer and HD as the row stride. - // write_smem_ktile expects src = (ROWS, hd) with k_tile selecting the column range. - // This works: src[r * hd + k_tile * KT + c] = k[r * HD + kt*KT + c]. ✓ - - __syncthreads(); - - // ============================================================ - // Construct UMMA descriptors - // ============================================================ - // Both A (Q) and B (K) are in K-major core-matrix layout. - // For the K^T operation: the hardware transposes via the descriptor. - // For tcgen05.mma, A and B are BOTH K-major. The MMA computes: - // D = A × B^T (for TN layout) - // where A is (M, K) and B^T is (K, N). - // But tcgen05.mma actually computes D = A × B^T when both are K-major? - // No — the MMA computes D = A × B^T when A is MN-major and B is K-major, - // or D = A^T × B^T when both are K-major. - // - // Actually, for tcgen05.mma with the instruction descriptor, - // we can set transpose flags. Let me re-read the spec. - // - // The MMA computes: D = f(A) × f(B) + D - // where f() can include transpose/negate based on idesc bits. - // Without transpose: D = A × B^T (K-major for both = TN layout) - // - // For FMHA QK: S = Q × K^T - // Q is (128, 16) — this is A - // K is (128, 16) — we want K^T = (16, 128) for B - // - // With both K-major (TN layout), the MMA computes: - // D = A × B^T = Q × K^T ✓ - // - // Wait, "TN" means A is row-major (T) and B is col-major (N). 
- // K-major is the SMEM layout, not the mathematical layout. - // In the MMA's perspective: - // A (K-major) = (M, K) with K contiguous → this IS row-major for A - // B (K-major) = (N, K) with K contiguous → B^T = (K, N) - // D = A × B^T = (M, K) × (K, N) = (M, N) ✓ - // - // So with both descriptors as K-major, the MMA naturally computes Q × K^T. - // No transpose flags needed. - - uint32_t sQ_smem = __cvta_generic_to_shared(sQ_ktile); - uint32_t sK_smem = __cvta_generic_to_shared(sK_ktile); - - // For a (128, 16) K-tile in core-matrix layout: - // LBO = 1, SBO = 2 - uint64_t desc_q = make_umma_desc_kmajor_none_ktile(sQ_smem); - uint64_t desc_k = make_umma_desc_kmajor_none_ktile(sK_smem); - - // ============================================================ - // Call tcgen05.mma SS (both SMEM → TMEM) - // ============================================================ - // Only ONE thread calls MMA (single-threaded launch model) - if (tid == 0) { - bool accumulate = (kt > 0); // First tile: zero, rest: accumulate - umma_ss_f16(tmem_s, desc_q, desc_k, accumulate); - } - __syncwarp(); // Ensure MMA is issued - // Wait for MMA to complete - if (wid == 0 && lane == 0) { - tmem_fence_store(); - } - __syncthreads(); + // Zero both buffers first + for (int i = tid; i < 128 * 16; i += NTHREADS) { + sQ_ktile[i] = 0; + sK_ktile[i] = 0; } + __syncthreads(); + + // Write Q (1, 16) padded to (128, 16) in core-matrix layout + // Core-matrix: each 8x8 BF16 tile at (tile_mn, tile_k) offset (tile_mn * 2 + tile_k) * 64 + // Q row 0, cols 0-15: tiles (0,0) and (0,1) + // Tile (0,0): positions 0-7 = Q[0..7], Tile (0,1): positions 0-7 = Q[8..15] + for (int d = tid; d < 16; d += NTHREADS) { + int tile_k = d / 8; + int local_c = d % 8; + int dst_idx = (0 * 2 + tile_k) * 64 + 0 * 8 + local_c; // row 0 (local_r=0) + sQ_ktile[dst_idx] = q[d]; + } + + // Write K (128, 16) in core-matrix layout + // K is (128, 16) row-major in GMEM. 
In core-matrix: + // tile (r/8, c/8) at offset ((r/8) * 2 + (c/8)) * 64 + (r%8)*8 + (c%8) + for (int i = tid; i < 128 * 16; i += NTHREADS) { + int r = i / 16; + int c = i % 16; + int tile_mn = r / 8; + int tile_k = c / 8; + int local_r = r % 8; + int local_c = c % 8; + int dst_idx = (tile_mn * 2 + tile_k) * 64 + local_r * 8 + local_c; + sK_ktile[dst_idx] = k[i]; + } + __syncthreads(); + + // ================================================================ + // Construct UMMA descriptors and call MMA + // ================================================================ + // Both A (Q) and B (K) are (128, 16) in K-major core-matrix layout. + // For K-major NONE with BLOCK_K=16: LBO=1, SBO=2 + uint32_t sQ_smem = __cvta_generic_to_shared(sQ_ktile); + uint32_t sK_smem = __cvta_generic_to_shared(sK_ktile); + + uint64_t desc_q = make_umma_desc_kmajor_none_ktile(sQ_smem); + uint64_t desc_k = make_umma_desc_kmajor_none_ktile(sK_smem); + + // Single-threaded MMA launch + if (tid == 0) { + umma_ss_f16(tmem_base, desc_q, desc_k, /*accumulate=*/false); + } + __syncwarp(); + + // Wait for MMA to complete + if (wid == 0 && lane == 0) { + tmem_fence_store(); + } + __syncthreads(); // ================================================================ // Read S from TMEM and write to output // ================================================================ - // S is (128, 128) FP32 in TMEM. Row 0 is the actual attention score. - // TMEM lane mapping: lane i reads positions i*4+0..3 per column. - // For row 0, lane 0 reads positions 0-3 from each column. + // S is (128, 128) FP32 in TMEM. We care about row 0 (the query row). + // TMEM column col stores 128 FP32. Lane 0's u0 = S[0, col]. if (wid == 0) { for (int col = 0; col < 128; col++) { uint32_t u0, u1, u2, u3; - tmem_load(tmem_s + col, u0, u1, u2, u3); - // Lane 0's u0 = row 0, position 0 in this column - // But which row does this correspond to in the S matrix? - // S is (128, 128) — row m, column n. 
- // TMEM column col stores 128 FP32 values (rows 0-127 of column col). - // Lane 0 reads positions 0-3: S[0,col], S[1,col], S[2,col], S[3,col] + tmem_load(tmem_base + col, u0, u1, u2, u3); if (lane == 0) { - float val = u32_to_f32(u0); // S[0, col] - if (col < sk) { - s_out[col] = val; - } + s_out[col] = u32_to_f32(u0); // S[0, col] (un-normalized) } } - tmem_fence_load(); } __syncthreads(); @@ -232,10 +155,10 @@ test_umma_qk_hd16( // Scalar reference: compute Q @ K^T in SMEM (row 0 only) // ================================================================ if (tid == 0) { - for (int c = 0; c < sk; c++) { + for (int c = 0; c < 128; c++) { float dot = 0.0f; - for (int d = 0; d < HD; d++) { - dot += sQ_row[d] * bf16_to_f32(k[c * HD + d]); + for (int d = 0; d < 16; d++) { + dot += sQ_row[d] * bf16_to_f32(k[c * 16 + d]); } s_scalar[c] = dot * scale; } @@ -244,7 +167,7 @@ test_umma_qk_hd16( // TMEM dealloc if (wid == 0) { - tmem_dealloc(tmem_base, TMEM_COLS); + tmem_dealloc(tmem_base, 128); } } @@ -255,9 +178,9 @@ test_umma_qk_hd16( int main() { printf("=== UMMA QK GEMM Test (HD=16, SK=128) ===\n"); - constexpr int HD = 16; - constexpr int SK = 128; - constexpr float SCALE = 1.0f / sqrtf((float)HD); + const int HD = 16; + const int SK = 128; + const float SCALE = 1.0f / sqrtf((float)HD); // Allocate host memory bf16_t* h_q = (bf16_t*)malloc(HD * sizeof(bf16_t)); @@ -288,22 +211,16 @@ int main() { cudaMemset(d_s_scalar, 0, SK * sizeof(float)); // Compute SMEM size - // sTmemBase: 4 bytes - // sQ_ktile: 128 * 16 * 2 = 4096 bytes (aligned to 16B) - // sK_ktile: 128 * 16 * 2 = 4096 bytes - // sQ_row: 16 * 4 = 64 bytes - // Total: 4 + 16 (alignment) + 4096 + 4096 + 64 = ~8276 bytes - int smem_size = 4 + 16 + 128 * MMA_K_TILE * 2 + 128 * MMA_K_TILE * 2 + HD * 4 + 256; // extra padding - smem_size = (smem_size + 127) & ~127; // align to 128B - + // sTmemBase: 4 + alignment 16 + sQ: 128*16*2 + sK: 128*16*2 + sQ_row: 16*4 + padding + int smem_size = 4 + 16 + 128*16*2 + 128*16*2 + 
16*4 + 256; + smem_size = (smem_size + 127) & ~127; printf("SMEM size: %d bytes\n", smem_size); - printf("Launching kernel with %d threads, 1 CTA\n", NTHREADS); // Launch dim3 grid(1, 1, 1); dim3 block(NTHREADS); test_umma_qk_hd16<<<grid, block, smem_size>>>( - d_q, d_k, d_s_out, d_s_scalar, SK, SCALE); + d_q, d_k, d_s_out, d_s_scalar, SCALE); cudaError_t err = cudaDeviceSynchronize(); if (err != cudaSuccess) { @@ -322,8 +239,7 @@ int main() { for (int c = 0; c < 8; c++) printf("%.4f ", h_s_scalar[c]); printf("\n"); - float max_diff = 0.0f; - float max_val = 0.0f; + float max_diff = 0.0f, max_val = 0.0f; for (int c = 0; c < SK; c++) { float diff = fabsf(h_s_out[c] - h_s_scalar[c]); max_diff = fmaxf(max_diff, diff);