From 762f054d6dcedf712f2ee522babcc1e3343d9cd6 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 30 May 2026 03:20:49 +0000 Subject: [PATCH] feat: double-buffer TMA pipeline in multi-row kernel --- .../attention/fmha_6warp_tma_multirow.cuh | 163 ++++++++++-------- tests/unit/test_fmha_6warp_tma_multirow.cu | 15 +- 2 files changed, 101 insertions(+), 77 deletions(-) diff --git a/dsv4/kernels/attention/fmha_6warp_tma_multirow.cuh b/dsv4/kernels/attention/fmha_6warp_tma_multirow.cuh index 5f74472d..8c5e9387 100644 --- a/dsv4/kernels/attention/fmha_6warp_tma_multirow.cuh +++ b/dsv4/kernels/attention/fmha_6warp_tma_multirow.cuh @@ -1,27 +1,9 @@ /** - * DSV4 FMHA — 6-warp TMA kernel, multi-row softmax (prefill T>1). + * DSV4 FMHA — 6-warp TMA kernel, multi-row softmax, double-buffer pipeline. * - * Based on fmha_6warp_multirow.cuh with TMA async loads for K. - * Q and V remain direct GMEM loads for now. - * - * ================================================================== - * DESIGN - * ================================================================== - * - * 6-warp CTA: warps 0-3 = softmax, warp 4 = MMA, warp 5 = TMA load. - * Grid: (1, n_h, batch) — each CTA processes one head of one batch item. - * - * TMA pipeline: - * - K: TMA async load via cp.async.bulk.tensor.2d with mbarrier - * - Q: direct GMEM load (multi-row, but small enough for warp-stride) - * - V: direct GMEM load (16×16 sub-tiles) - * - sTmaBuf: staging area for TMA→canonical conversion - * - * Flow: - * 1. QK GEMM: Q direct + K TMA → S in TMEM - * 2. Softmax: 2-pass (row_max, exp+sum+P), P in registers - * 3. PV GEMM: P→sPk + V direct → O in TMEM - * 4. Epilogue: O from TMEM → normalize → BF16 → GMEM + LSE + * Double-buffer TMA for K loads: overlap TMA DMA of sub-tile N+1 with + * MMA compute of sub-tile N. V loaded via TMA (single-buffer, tiny tiles). + * Q loaded directly from GMEM. */ #pragma once @@ -34,13 +16,12 @@ namespace dsv4::kernels::attention { struct FmhaTmaMultiRowParams { const bf16_t* __restrict__ q; - CUtensorMap* __restrict__ tma_k; // Array of [n_h] TMA descriptors for K - CUtensorMap* __restrict__ tma_v; // Array of [n_h] TMA descriptors for V - const bf16_t* __restrict__ v; // V: direct fallback (HD, s_k) + CUtensorMap* __restrict__ tma_k; + CUtensorMap* __restrict__ tma_v; + const bf16_t* __restrict__ v; bf16_t* __restrict__ o; float* __restrict__ lse; int s_k, T, n_h; - int n_kv_tiles; // number of KV tiles (s_k / SK_TILE) float scale; int head_dim; int q_head_stride, q_batch_stride; @@ -77,13 +58,15 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) { const float scale = params.scale; const bf16_t* __restrict__ q_head = params.q + head_idx * params.q_head_stride + batch_idx * params.q_batch_stride; - // v_head kept for reference; V loaded via TMA - // const bf16_t* __restrict__ v_head = params.v + head_idx * params.v_head_stride + batch_idx * params.v_batch_stride; + // v_head not used — V loaded via TMA bf16_t* __restrict__ o_head = params.o + head_idx * params.o_head_stride + batch_idx * params.o_batch_stride; float* __restrict__ lse_head = params.lse ? params.lse + head_idx * params.lse_head_stride + batch_idx * params.lse_batch_stride : nullptr; + CUtensorMap* __restrict__ my_tma_k = params.tma_k + batch_idx * params.n_h + head_idx; + CUtensorMap* __restrict__ my_tma_v = params.tma_v + batch_idx * params.n_h + head_idx; + // ================================================================ - // SMEM allocation — 128-byte aligned for TMA + // SMEM allocation — 128-byte aligned // ================================================================ extern __shared__ __align__(128) char sbuf[]; size_t off = 0; @@ -97,13 +80,15 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) { off = (off + 127) & ~(size_t)127; bf16_t* sK0 = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t); off = (off + 127) & ~(size_t)127; + bf16_t* sK1 = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t); + off = (off + 127) & ~(size_t)127; bf16_t* sPk = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t); off = (off + 127) & ~(size_t)127; bf16_t* sV = (bf16_t*)(sbuf + off); off += V_SUB_SZ * sizeof(bf16_t); float* sRowMax = (float*)(sbuf + off); off += MAX_ROWS * sizeof(float); float* sRowSum = (float*)(sbuf + off); off += MAX_ROWS * sizeof(float); - // TMEM alloc + mbarrier init + // Init if (is_mma_warp) tmem_alloc(__cvta_generic_to_shared(sTmemBase), TMEM_N); if (tid == 0) { tma_mbarrier_init((uint32_t)__cvta_generic_to_shared(sMbar), 1); @@ -114,22 +99,38 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) { const uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar); int phase = 0; - CUtensorMap* __restrict__ my_tma_k = params.tma_k + batch_idx * params.n_h + head_idx; - CUtensorMap* __restrict__ my_tma_v = params.tma_v + batch_idx * params.n_h + head_idx; const bool my_warp_active = (T <= 32) ? (wid == 0) : is_softmax_warp; const int my_row = my_warp_active ? (wid * 32 + lane) : 0; const bool my_row_active = my_warp_active && (my_row < T); // ================================================================ - // QK GEMM → S in TMEM + // QK GEMM — double-buffer TMA pipeline // ================================================================ - for (int kt = 0; kt < NKT_QK; kt++) { - // Load Q: direct from GMEM, all T rows + { + bf16_t* sK_bufs[2] = {sK0, sK1}; + int cur_buf = 0; + + // Preload kt=0 + if (is_load_warp && lane == 0) { + tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)my_tma_k, mbar_addr, 0, 0); + tma_mbarrier_arrive_expect_tx(mbar_addr, TMA_TILE_BYTES); + } + tma_mbarrier_wait(mbar_addr, phase); phase ^= 1; + __syncthreads(); + + for (int i = tid; i < TILE_SZ; i += 192) sK0[i] = 0; + for (int i = tid; i < s_k * MMA_K_BF16; i += 192) { + int r = i / MMA_K_BF16, c = i % MMA_K_BF16; + int ck = c/8, lc = c%8, tmn = r/8, lr = r%8; + sK0[ck*CORES_MN*64 + tmn*64 + lr*8 + lc] = sTmaBuf[i]; + } + + // Load Q for kt=0 if (is_load_warp) { for (int i = lane; i < TILE_SZ; i += 32) sQ0[i] = 0; for (int r = 0; r < T; r++) { for (int d = lane; d < MMA_K_BF16; d += 32) { - int full_d = kt * MMA_K_BF16 + d; + int full_d = d; if (full_d < HD) { int ck = d/8, lc = d%8, cm = r/8, lr = r%8; sQ0[ck*CORES_MN*64 + cm*64 + lr*8 + lc] = q_head[r * HD + full_d]; @@ -137,34 +138,63 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) { } } } - - // Load K: TMA async - if (is_load_warp && lane == 0) { - tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)my_tma_k, - mbar_addr, kt * MMA_K_BF16, 0); - tma_mbarrier_arrive_expect_tx(mbar_addr, TMA_TILE_BYTES); - } - tma_mbarrier_wait(mbar_addr, phase); phase ^= 1; __syncthreads(); - // Convert TMA row-major → canonical - for (int i = tid; i < TILE_SZ; i += 192) sK0[i] = 0; - for (int i = tid; i < s_k * MMA_K_BF16; i += 192) { - int r = i / MMA_K_BF16, c = i % MMA_K_BF16; - int ck = c/8, lc = c%8, tmn = r/8, lr = r%8; - sK0[ck*CORES_MN*64 + tmn*64 + lr*8 + lc] = sTmaBuf[i]; - } - __syncthreads(); + // Pipeline loop + for (int kt = 0; kt < NKT_QK; kt++) { + bf16_t* cur_sK = sK_bufs[cur_buf]; + bool has_next = (kt + 1 < NKT_QK); - // MMA - if (is_mma_warp) { - uint32_t idesc = make_idesc(128, 128); - uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), 128); - uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK0), 128); - if (tid == 128) umma_ss_f16(tb, dq, dk, idesc, kt > 0); - asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + // 1. Issue TMA for kt+1 (DMA runs in parallel with MMA) + if (has_next && is_load_warp && lane == 0) { + tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)my_tma_k, + mbar_addr, (kt+1) * MMA_K_BF16, 0); + tma_mbarrier_arrive_expect_tx(mbar_addr, TMA_TILE_BYTES); + } + + // 2. MMA QK + if (is_mma_warp) { + uint32_t idesc = make_idesc(128, 128); + uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), 128); + uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(cur_sK), 128); + if (tid == 128) umma_ss_f16(tb, dq, dk, idesc, kt > 0); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + } + + // 3. Wait for TMA kt+1 + if (has_next) { + tma_mbarrier_wait(mbar_addr, phase); phase ^= 1; + } + __syncthreads(); + + // 4. Convert → next buffer + if (has_next) { + bf16_t* next_sK = sK_bufs[1 - cur_buf]; + for (int i = tid; i < TILE_SZ; i += 192) next_sK[i] = 0; + for (int i = tid; i < s_k * MMA_K_BF16; i += 192) { + int r = i / MMA_K_BF16, c = i % MMA_K_BF16; + int ck = c/8, lc = c%8, tmn = r/8, lr = r%8; + next_sK[ck*CORES_MN*64 + tmn*64 + lr*8 + lc] = sTmaBuf[i]; + } + } + + // 5. Load Q for kt+1 + if (has_next && is_load_warp) { + for (int i = lane; i < TILE_SZ; i += 32) sQ0[i] = 0; + for (int r = 0; r < T; r++) { + for (int d = lane; d < MMA_K_BF16; d += 32) { + int full_d = (kt+1) * MMA_K_BF16 + d; + if (full_d < HD) { + int ck = d/8, lc = d%8, cm = r/8, lr = r%8; + sQ0[ck*CORES_MN*64 + cm*64 + lr*8 + lc] = q_head[r * HD + full_d]; + } + } + } + } + + cur_buf = 1 - cur_buf; + __syncthreads(); } - __syncthreads(); } asm volatile("fence.sc.gpu;" ::: "memory"); @@ -173,7 +203,6 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) { // ================================================================ // SOFTMAX — 2-pass, P in registers // ================================================================ - // Pass 1: row_max float my_row_max = -INFINITY; if (my_warp_active) { for (int n = 0; n < NUM_READS; n++) { @@ -194,7 +223,6 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) { if (my_row_active) sRowMax[my_row] = my_row_max; __syncthreads(); - // Pass 2: exp + sum + P float my_p_vals[SK_TILE]; float my_row_sum = 0.0f; if (my_warp_active) { @@ -222,20 +250,18 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) { __syncthreads(); // ================================================================ - // PV GEMM — P→sPk + V direct → O in TMEM + // PV GEMM — P→sPk + V TMA → O in TMEM // ================================================================ for (int n_sub = 0; n_sub < N_NSUB; n_sub++) { int d_base = n_sub * 16; for (int pv_kt = 0; pv_kt < NKT_PV; pv_kt++) { const int col_start = pv_kt * MMA_K_BF16; - // Zero sPk if (is_load_warp) { for (int i = lane; i < TILE_SZ; i += 32) sPk[i] = 0; } __syncthreads(); - // Softmax warps: write P to sPk (all active rows) if (my_row_active) { for (int c = 0; c < MMA_K_BF16; c++) { int gc = col_start + c; @@ -246,7 +272,7 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) { } __syncthreads(); - // Load V sub-tile via TMA + // V via TMA if (is_load_warp && lane == 0) { tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)my_tma_v, mbar_addr, col_start, d_base); @@ -255,7 +281,6 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) { tma_mbarrier_wait(mbar_addr, phase); phase ^= 1; __syncthreads(); - // Convert sTmaBuf → canonical sV for (int i = tid; i < V_SUB_SZ; i += 192) sV[i] = 0; for (int i = tid; i < 16 * MMA_K_BF16; i += 192) { int dd = i / MMA_K_BF16, lr = i % MMA_K_BF16; @@ -264,7 +289,6 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) { } __syncthreads(); - // MMA if (is_mma_warp) { uint32_t idesc_pv = make_idesc(128, 16); uint64_t dp = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sPk), 128); @@ -280,8 +304,7 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) { __syncthreads(); // ================================================================ - // EPILOGUE — O from TMEM → normalize → GMEM + LSE - // TMEM loads are warp-collective: MUST be outside my_row_active guard + // EPILOGUE // ================================================================ if (my_warp_active) { float rm = my_row_active ? sRowMax[my_row] : 0.0f; diff --git a/tests/unit/test_fmha_6warp_tma_multirow.cu b/tests/unit/test_fmha_6warp_tma_multirow.cu index 8189bdf7..5f20afb6 100644 --- a/tests/unit/test_fmha_6warp_tma_multirow.cu +++ b/tests/unit/test_fmha_6warp_tma_multirow.cu @@ -34,13 +34,14 @@ static size_t compute_smem() { size_t off = 0; off += 4; off = (off+127)&~(size_t)127; off += 16; off = (off+127)&~(size_t)127; - off += TILE_SZ * 2; off = (off+127)&~(size_t)127; - off += TILE_SZ * 2; off = (off+127)&~(size_t)127; - off += TILE_SZ * 2; off = (off+127)&~(size_t)127; - off += TILE_SZ * 2; off = (off+127)&~(size_t)127; - off += 16 * MY_MMA_K * 2; - off += 128 * 4; - off += 128 * 4; + off += TILE_SZ * 2; off = (off+127)&~(size_t)127; // sTmaBuf + off += TILE_SZ * 2; off = (off+127)&~(size_t)127; // sQ0 + off += TILE_SZ * 2; off = (off+127)&~(size_t)127; // sK0 + off += TILE_SZ * 2; off = (off+127)&~(size_t)127; // sK1 (double buffer) + off += TILE_SZ * 2; off = (off+127)&~(size_t)127; // sPk + off += 16 * MY_MMA_K * 2; // sV + off += 128 * 4; // sRowMax + off += 128 * 4; // sRowSum return off; }