From c55030a340e7d751c96302fcf6629e5cca1d4e72 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 30 May 2026 08:57:00 +0000 Subject: [PATCH] P5: clean kernel with runtime branch (single-tile unchanged, multi-tile separate path) Single-tile path is IDENTICAL to the working pre-P5 kernel. Multi-tile path uses FA2 online softmax with sOacc accumulator. Runtime branch on is_multi_tile = (n_kv_tiles > 1). --- .../attention/fmha_6warp_multihead.cuh | 322 ++++++++++-------- 1 file changed, 182 insertions(+), 140 deletions(-) diff --git a/dsv4/kernels/attention/fmha_6warp_multihead.cuh b/dsv4/kernels/attention/fmha_6warp_multihead.cuh index 9d19e4b8..107fa999 100644 --- a/dsv4/kernels/attention/fmha_6warp_multihead.cuh +++ b/dsv4/kernels/attention/fmha_6warp_multihead.cuh @@ -5,48 +5,26 @@ * MULTI-HEAD + MULTI-KV-TILE (P3 + P5) * ================================================================== * Grid: dim3(1, n_h, batch_size) - * blockIdx.y = head index (0..n_h-1) - * blockIdx.z = batch index (0..batch_size-1) - * * Each CTA processes one head of one batch item independently. - * No cross-CTA synchronization required. * - * For n_kv_tiles=1 (single KV segment): same as before, no rescale. - * For n_kv_tiles>1: FlashAttention-2 online softmax with running max/sum. - * - P is UN-NORMALIZED (exp(s - max), not divided by sum) - * - After each KV tile: read PV result from TMEM → add to sOacc - * - Rescale sOacc when running_max changes: sOacc *= exp(old_max - new_max) - * - Final: O_normalized = sOacc / running_sum + * n_kv_tiles <= 1 (single KV segment, s_k <= 128): + * Direct QK → softmax → PV → epilogue. No rescale. + * Identical to the pre-P5 kernel. + * + * n_kv_tiles > 1 (multiple KV segments, s_k > 128): + * FlashAttention-2 online softmax across KV tiles. + * - P is UN-NORMALIZED: exp(s - max), NOT divided by sum + * - SMEM accumulator sOacc[HD]: O += PV after each tile + * - Rescale: sOacc *= exp(old_max - new_max) when max changes + * - Final: O = sOacc / running_sum * * ================================================================== - * MQA / GQA SUPPORT + * SMEM BUDGET (additional for multi-tile) * ================================================================== - * Same as single-tile. K/V head strides handled by caller. - * For MQA: k_head_stride=0, v_head_stride=0. - * For GQA: K/V expanded via repeat_interleave in Python. - * - * ================================================================== - * SMEM LAYOUT (P5 additions in comments) - * ================================================================== - * sTmemBase: 4 bytes - * sRowMax: 4 bytes (running max across all KV tiles) - * sRowSum: 4 bytes (running sum across all KV tiles) - * padding: ~4 bytes (alignment) - * sQ0: 128 * hd * 2 (Q canonical, reused per KV tile) - * sK0: 128 * 16 * 2 (K segment canonical, one SK_TILE at a time) - * sPk: 128 * 16 * 2 (P sub-tile canonical, reused per PV) - * sV: 16 * 16 * 2 (V sub-tile canonical, reused per PV) - * s_p_vals: SK_TILE * 4 (un-normalized P values) - * sOacc: hd * 4 (float accumulator, 1 row for T=1) - * Total additional vs single-tile: ~hd*4 bytes (sOacc) + * sOacc: HD * 4 bytes (float accumulator, 1 row for T=1) * hd=64: +256B. hd=128: +512B. hd=256: +1024B. - * All well within 232KB SMEM budget. - * - * ================================================================== - * OUTPUT: NORMALIZED O + LSE - * ================================================================== - * O: normalized attention output (divided by running_sum) - * LSE: ln(running_sum) + running_max — for external multi-segment merge if needed + * Total SMEM at hd=64: ~14 KB. hd=128: ~15 KB. hd=256: ~17 KB. + * Well within 232 KB SMEM budget. */ #pragma once @@ -66,7 +44,7 @@ struct FmhaParams { int s_k; // Total KV sequence length float scale; // 1/sqrt(hd) int head_dim; // hd - int n_kv_tiles; // Number of KV tiles (0 = auto-calc from s_k) + int n_kv_tiles; // Number of KV tiles (0 or 1 = single-tile, >1 = multi-tile) int q_head_stride, q_batch_stride; int k_head_stride, k_batch_stride; @@ -95,7 +73,6 @@ fmha_6warp_multihead_kernel(FmhaParams params) { const bool is_mma_warp = (wid == 4); const bool is_load_warp = (wid == 5); - // Per-head GMEM pointers const bf16_t* __restrict__ q_head = params.q + head_idx * params.q_head_stride + batch_idx * params.q_batch_stride; const bf16_t* __restrict__ k_head = params.k @@ -113,23 +90,25 @@ fmha_6warp_multihead_kernel(FmhaParams params) { const float scale = params.scale; const int n_kv_tiles = (params.n_kv_tiles > 0) ? params.n_kv_tiles : (s_k + SK_TILE - 1) / SK_TILE; + const bool is_multi_tile = (n_kv_tiles > 1); // ================================================================ // SMEM allocation // ================================================================ extern __shared__ char sbuf[]; uint32_t* sTmemBase = (uint32_t*)sbuf; - float* sRowMax = (float*)(sbuf + 4); // running max across all tiles - float* sRowSum = sRowMax + 1; // running sum across all tiles + float* sRowMax = (float*)(sbuf + 4); + float* sRowSum = sRowMax + 1; bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sRowSum + 1) + 15) & ~(uintptr_t)15); bf16_t* sK0 = sQ0 + TILE_SZ; bf16_t* sPk = (bf16_t*)(((uintptr_t)(sK0 + TILE_SZ) + 127) & ~(uintptr_t)127); bf16_t* sV = (bf16_t*)(((uintptr_t)(sPk + TILE_SZ) + 127) & ~(uintptr_t)127); float* s_p_vals = (float*)(sV + V_SUB_SZ); + // Multi-tile accumulator (only used when n_kv_tiles > 1) float* sOacc = (float*)(((uintptr_t)(s_p_vals + SK_TILE) + 15) & ~(uintptr_t)15); - // Initialize accumulator (single thread to avoid race) - if (tid == 0) { + // Initialize multi-tile accumulator + if (is_multi_tile && tid == 0) { *sRowMax = -INFINITY; *sRowSum = 0.0f; for (int d = 0; d < HD; d++) sOacc[d] = 0.0f; @@ -144,37 +123,28 @@ fmha_6warp_multihead_kernel(FmhaParams params) { uint32_t tb = *sTmemBase; // ================================================================ - // MAIN LOOP: iterate over KV tiles + // SINGLE-TILE PATH (n_kv_tiles <= 1) + // Identical to the pre-P5 kernel. Tested, proven correct. // ================================================================ - for (int kv_tile = 0; kv_tile < n_kv_tiles; kv_tile++) { - int kv_start = kv_tile * SK_TILE; - int kv_len = min(SK_TILE, s_k - kv_start); - - // ============================================================ - // QK GEMM: for each K-tile (in hd), load Q + K segment, MMA - // ============================================================ + if (!is_multi_tile) { + // QK GEMM for (int kt = 0; kt < NKT_QK; kt++) { if (is_load_warp) { - // Load Q (same for all KV tiles) for (int i = lane; i < TILE_SZ; i += 32) sQ0[i] = 0; for (int d = lane; d < MMA_K_BF16; d += 32) { int ck = d / 8, lc = d % 8; sQ0[ck * 16 * 64 + lc] = q_head[kt * MMA_K_BF16 + d]; } - // Load K segment for (int i = lane; i < TILE_SZ; i += 32) sK0[i] = 0; - for (int r = 0; r < kv_len; r++) { - int g_r = kv_start + r; + for (int r = 0; r < s_k; r++) { for (int d = lane; d < MMA_K_BF16; d += 32) { int ck = d / 8, lc = d % 8; int tmn = r / 8, lr = r % 8; - sK0[ck * 16 * 64 + tmn * 64 + lr * 8 + lc] = - k_head[g_r * HD + kt * MMA_K_BF16 + d]; + sK0[ck * 16 * 64 + tmn * 64 + lr * 8 + lc] = k_head[r * HD + kt * MMA_K_BF16 + d]; } } } __syncthreads(); - if (is_mma_warp) { uint32_t idesc = make_idesc(128, 128); uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), 128); @@ -185,10 +155,7 @@ fmha_6warp_multihead_kernel(FmhaParams params) { __syncthreads(); } - // ============================================================ - // Softmax (warp 0, row 0 for T=1 decode) - // P is UN-NORMALIZED for multi-tile: exp(s - max), NOT / sum - // ============================================================ + // Softmax (normalized P for single-tile) if (wid == 0) { float s_vals[SK_TILE], row_max = -INFINITY; for (int n = 0; n < SK_TILE / 8; n++) { @@ -199,102 +166,62 @@ fmha_6warp_multihead_kernel(FmhaParams params) { : "r"(tb + n*8)); asm volatile("tcgen05.wait::ld.sync.aligned;"); if (lane == 0) for (int c=0;c<8;c++) { - float val = tmp[c] * scale; - s_vals[n*8+c] = val; - row_max = fmaxf(row_max, val); + s_vals[n*8+c] = tmp[c] * scale; + row_max = fmaxf(row_max, tmp[c] * scale); } } row_max = wmax(row_max); - + if (lane == 0) *sRowMax = row_max; float row_sum = 0.0f; if (lane == 0) for (int j=0;j -INFINITY) { - float rescale = expf(old_max - new_max); - for (int d = 0; d < HD; d++) sOacc[d] *= rescale; - old_sum *= rescale; - } - *sRowMax = new_max; - } - - // Store P for PV. For single KV tile, use NORMALIZED P (same as old kernel). - // For multi-tile, use UN-NORMALIZED P (critical for FA2 rescale). - // Single-tile: O = P_norm × V. Multi-tile: O = Σ(P_unnorm × V) / running_sum - if (lane == 0) { - if (n_kv_tiles == 1) { - // Normalized P (backward compatible with single-tile) - for (int j=0;j 0); asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); } __syncthreads(); } } - // ============================================================ - // Read PV result from TMEM → add to sOacc - // After reading, TMEM can be reused for the next KV tile. - // ============================================================ + // Epilogue: TMEM → regs → BF16 → GMEM (P was normalized, no division needed) if (wid == 0) { + float row_max = *sRowMax; + float row_sum = *sRowSum; + float o_vals[HD]; for (int n = 0; n < HD / 8; n++) { float tmp[8]; asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" @@ -302,45 +229,160 @@ fmha_6warp_multihead_kernel(FmhaParams params) { "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) : "r"(tb + n*8)); asm volatile("tcgen05.wait::ld.sync.aligned;"); - if (lane == 0) for (int c=0;c<8;c++) sOacc[n*8+c] += tmp[c]; + if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c]; + } + if (lane == 0) { + for (int d = 0; d < HD; d++) o_head[d] = f32_to_bf16(o_vals[d]); + if (lse_head) lse_head[0] = logf(row_sum) + row_max; } } __syncthreads(); - } // end KV tile loop - // ================================================================ - // Epilogue: write O to GMEM - // For single KV tile (normalized P): sOacc = P_norm × V (already normalized) - // For multi KV tile (un-normalized P): O = sOacc / running_sum + // MULTI-TILE PATH (n_kv_tiles > 1) + // FlashAttention-2 online softmax across KV tiles. // ================================================================ - if (wid == 0) { - float running_max = *sRowMax; - float running_sum = *sRowSum; + } else { + for (int kv_tile = 0; kv_tile < n_kv_tiles; kv_tile++) { + int kv_start = kv_tile * SK_TILE; + int kv_len = min(SK_TILE, s_k - kv_start); - if (lane == 0) { - if (n_kv_tiles > 1) { - float inv_sum = 1.0f / running_sum; - for (int d = 0; d < HD; d++) { - o_head[d] = f32_to_bf16(sOacc[d] * inv_sum); + // QK GEMM (same as single-tile, but K offset = kv_start) + for (int kt = 0; kt < NKT_QK; kt++) { + if (is_load_warp) { + for (int i = lane; i < TILE_SZ; i += 32) sQ0[i] = 0; + for (int d = lane; d < MMA_K_BF16; d += 32) { + int ck = d / 8, lc = d % 8; + sQ0[ck * 16 * 64 + lc] = q_head[kt * MMA_K_BF16 + d]; + } + for (int i = lane; i < TILE_SZ; i += 32) sK0[i] = 0; + for (int r = 0; r < kv_len; r++) { + int g_r = kv_start + r; + for (int d = lane; d < MMA_K_BF16; d += 32) { + int ck = d / 8, lc = d % 8; + int tmn = r / 8, lr = r % 8; + sK0[ck * 16 * 64 + tmn * 64 + lr * 8 + lc] = k_head[g_r * HD + kt * MMA_K_BF16 + d]; + } + } } - } else { - // Single tile: P was normalized, sOacc is already the correct output - for (int d = 0; d < HD; d++) { - o_head[d] = f32_to_bf16(sOacc[d]); + __syncthreads(); + if (is_mma_warp) { + uint32_t idesc = make_idesc(128, 128); + uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), 128); + uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK0), 128); + if (tid == 128) umma_ss_f16(tb, dq, dk, idesc, kt > 0); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + } + __syncthreads(); + } + + // Softmax: UN-NORMALIZED P (exp(s - max), NOT / sum) + if (wid == 0) { + float s_vals[SK_TILE], row_max = -INFINITY; + for (int n = 0; n < SK_TILE / 8; n++) { + float tmp[8]; + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" + : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), + "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) + : "r"(tb + n*8)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + if (lane == 0) for (int c=0;c<8;c++) { + s_vals[n*8+c] = tmp[c] * scale; + row_max = fmaxf(row_max, tmp[c] * scale); + } + } + row_max = wmax(row_max); + + float row_sum = 0.0f; + if (lane == 0) for (int j=0;j -INFINITY) { + float rescale = expf(old_max - row_max); + for (int d = 0; d < HD; d++) sOacc[d] *= rescale; + old_sum *= rescale; + } + *sRowMax = row_max; + *sRowSum = old_sum + row_sum; + } + + // Store UN-NORMALIZED P for PV + if (lane == 0) for (int j=0;j