diff --git a/tests/unit/test_umma_qk_hd64.cu b/tests/unit/test_umma_qk_hd64.cu index 54fcddc2..d54dfe2b 100644 --- a/tests/unit/test_umma_qk_hd64.cu +++ b/tests/unit/test_umma_qk_hd64.cu @@ -1,15 +1,8 @@ /** - * UMMA QK GEMM Test — HD=64 (4 K-tiles) + * UMMA QK GEMM Test — HD=64 (4 K-tiles, separate SMEM per K-tile) * - * Full multi-K-tile QK GEMM with proper SMEM writes. - * Key fix: source data stride ≠ SMEM tile width — must write manually. - * - * Pipeline: - * 1. Load Q (1, 64) into (128, 64) canonical SMEM - * 2. Load K (128, 64) into (128, 64) canonical SMEM - * 3. For each K-tile (16 BF16): construct offset descriptor, call MMA with accumulate - * 4. Read S from TMEM, apply 1/sqrt(HD) scale - * 5. Compare against scalar reference + * Each K-tile gets its own (128, 16) SMEM region — no offset descriptors. + * Source data stride handled correctly (SRC_HD=64, SMEM_HD=16). */ #include @@ -29,43 +22,7 @@ constexpr int HD = 64; constexpr int SK = 128; constexpr int NKT = HD / MMA_K_BF16; // 4 constexpr int BLOCK_MN = 128; - -/** - * Write Q (1, SRC_HD) into (128, SMEM_HD) canonical layout. - * Only row 0 has data. Source stride = SRC_HD, SMEM cols = SMEM_HD. - */ -template -__device__ void write_q_canonical(bf16_t* dst, const bf16_t* q) { - constexpr int CORES_MN = 128 / 8; // 16 - constexpr int CORES_K = SMEM_HD / 8; - // Zero all - for (int i = threadIdx.x; i < 128 * SMEM_HD; i += 128) dst[i] = 0; - // Row 0 only: core_mn=0, local_r=0 - for (int c = threadIdx.x; c < SRC_HD; c += 128) { - int ck = c / 8, lc = c % 8; - dst[ck * CORES_MN * 64 + lc] = q[c]; - } -} - -/** - * Write K (SK, SRC_HD) into (128, SMEM_HD) canonical layout. - * Source stride = SRC_HD, SMEM cols = SMEM_HD. - */ -template -__device__ void write_k_canonical(bf16_t* dst, const bf16_t* k) { - constexpr int CORES_MN = 128 / 8; // 16 - // Zero all - for (int i = threadIdx.x; i < 128 * SMEM_HD; i += 128) dst[i] = 0; - // Write actual rows - for (int i = threadIdx.x; i < SK_VAL * SMEM_HD; i += 128) { - int r = i / SMEM_HD; - int c = i % SMEM_HD; - if (r >= SK_VAL || c >= SRC_HD) continue; - int ck = c / 8, lc = c % 8; - int tmn = r / 8, lr = r % 8; - dst[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = k[r * SRC_HD + c]; - } -} +constexpr int TILE_SZ = BLOCK_MN * MMA_K_BF16; // 128*16 = 2048 BF16 per K-tile __global__ void __launch_bounds__(128) test_umma_hd64(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k, @@ -74,52 +31,70 @@ test_umma_hd64(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k, const int tid = threadIdx.x; const int wid = tid / WARP, lane = tid % WARP; + // SMEM: tmem_base(4) + pad(12) + Q tiles (4 × 2048 BF16) + K tiles (4 × 2048 BF16) extern __shared__ char sbuf[]; uint32_t* sTmemBase = (uint32_t*)sbuf; - bf16_t* sQ = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15); - bf16_t* sK = sQ + 128 * HD; // (128, 64) each = 16384 bytes + bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15); + bf16_t* sQ1 = sQ0 + TILE_SZ; + bf16_t* sQ2 = sQ1 + TILE_SZ; + bf16_t* sQ3 = sQ2 + TILE_SZ; + bf16_t* sK0 = sQ3 + TILE_SZ; + bf16_t* sK1 = sK0 + TILE_SZ; + bf16_t* sK2 = sK1 + TILE_SZ; + bf16_t* sK3 = sK2 + TILE_SZ; - // Load Q (1, 64) → (128, 64) canonical - write_q_canonical(sQ, q); - // Load K (128, 64) → (128, 64) canonical - write_k_canonical(sK, k); + constexpr int CORES_MN = 16; // 128/8 + + // Load Q K-tiles: Q is (1, 64), each K-tile takes 16 dims + // Zero all tiles + for (int i = tid; i < NKT * TILE_SZ; i += 128) { sQ0[i] = 0; sK0[i] = 0; } __syncthreads(); - // TMEM alloc — 128 columns for (128, 128) Layout D output + // Write Q row 0 to each K-tile's SMEM + for (int kt = 0; kt < NKT; kt++) { + bf16_t* sq = sQ0 + kt * TILE_SZ; + for (int d = tid; d < MMA_K_BF16; d += 128) { + int ck = d / 8, lc = d % 8; + sq[ck * CORES_MN * 64 + lc] = q[kt * MMA_K_BF16 + d]; + } + // Write K for this K-tile: K[r, 16*kt + d] for r=0..127, d=0..15 + bf16_t* sk = sK0 + kt * TILE_SZ; + for (int r = 0; r < SK; r++) { + for (int d = tid; d < MMA_K_BF16; d += 128) { + int ck = d / 8, lc = d % 8; + int tmn = r / 8, lr = r % 8; + sk[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = k[r * HD + kt * MMA_K_BF16 + d]; + } + } + } + __syncthreads(); + + // TMEM alloc if (wid == 1) { tmem_alloc(__cvta_generic_to_shared(sTmemBase), 128); } __syncthreads(); uint32_t tb = *sTmemBase; - // Multi-K-tile QK GEMM - uint32_t sQ_smem = __cvta_generic_to_shared(sQ); - uint32_t sK_smem = __cvta_generic_to_shared(sK); + // Multi-K-tile QK GEMM with separate SMEM per K-tile + bf16_t* sQ_arr[NKT] = {sQ0, sQ1, sQ2, sQ3}; + bf16_t* sK_arr[NKT] = {sK0, sK1, sK2, sK3}; uint32_t idesc = make_idesc(BLOCK_MN, BLOCK_MN); - for (int kt = 0; kt < 1; kt++) { // DEBUG: single K-tile from full SMEM - // K-tile offset in canonical layout: - // Each 16-BF16 K-tile spans 2 core columns. - // Core column 2*kt starts at offset 2*kt * (128/8 * 128) bytes = 2*kt * 2048 bytes = kt * 4096 bytes. - uint32_t q_addr = sQ_smem + kt * BLOCK_MN * 32; - uint32_t k_addr = sK_smem + kt * BLOCK_MN * 32; + for (int kt = 0; kt < NKT; kt++) { + uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ_arr[kt]), BLOCK_MN); + uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK_arr[kt]), BLOCK_MN); - uint64_t dq = make_umma_desc_kmajor_none(q_addr, BLOCK_MN); - uint64_t dk = make_umma_desc_kmajor_none(k_addr, BLOCK_MN); - - // Single thread calls MMA (gau-nernst's elect_one pattern) if (tid == 0) { umma_ss_f16(tb, dq, dk, idesc, kt > 0); } asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); __syncthreads(); } - - // Final fence before TMEM read asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); __syncthreads(); - // Read S from TMEM (Layout D: 32x32b.x8) + // Read S from TMEM for (int n = 0; n < 128 / 8; n++) { const int row = wid * 32; const int col = n * 8; @@ -144,7 +119,7 @@ test_umma_hd64(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k, if (tid == 0) { for (int j = 0; j < SK; j++) { float dot = 0.0f; - for (int d = 0; d < 16; d++) // DEBUG: single K-tile + for (int d = 0; d < HD; d++) dot += bf16_to_f32(q[d]) * bf16_to_f32(k[j * HD + d]); s_scalar[j] = dot * scale; } @@ -154,7 +129,7 @@ test_umma_hd64(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k, } int main() { - printf("=== UMMA QK GEMM HD=64 (4 K-tiles, fixed stride) ===\n"); + printf("=== UMMA QK GEMM HD=64 (separate SMEM per K-tile) ===\n"); const float SCALE = 1.0f / sqrtf((float)HD); bf16_t* h_q = (bf16_t*)malloc(HD * sizeof(bf16_t)); @@ -172,7 +147,8 @@ int main() { cudaMemcpy(d_q, h_q, HD*sizeof(bf16_t), cudaMemcpyHostToDevice); cudaMemcpy(d_k, h_k, SK*HD*sizeof(bf16_t), cudaMemcpyHostToDevice); - int smem = (4 + 16 + 2 * 128 * HD * sizeof(bf16_t) + 256 + 127) & ~127; + // SMEM: 4 + 12(pad) + 8 * 2048*2 + int smem = (4 + 16 + NKT * 2 * TILE_SZ * sizeof(bf16_t) + 256 + 127) & ~127; printf("SMEM: %d bytes (%d KB)\n", smem, smem / 1024); test_umma_hd64<<<1, 128, smem>>>(d_q, d_k, d_s_out, d_s_scalar, SCALE);