P5: clean kernel with runtime branch (single-tile unchanged, multi-tile separate path)

Single-tile path is IDENTICAL to the working pre-P5 kernel.
Multi-tile path uses FA2 online softmax with sOacc accumulator.
Runtime branch on is_multi_tile = (n_kv_tiles > 1).
This commit is contained in:
2026-05-30 08:57:00 +00:00
parent 5f4856d771
commit c55030a340

View File

@@ -5,48 +5,26 @@
* MULTI-HEAD + MULTI-KV-TILE (P3 + P5)
* ==================================================================
* Grid: dim3(1, n_h, batch_size)
* blockIdx.y = head index (0..n_h-1)
* blockIdx.z = batch index (0..batch_size-1)
*
* Each CTA processes one head of one batch item independently.
* No cross-CTA synchronization required.
*
* For n_kv_tiles=1 (single KV segment): same as before, no rescale.
* For n_kv_tiles>1: FlashAttention-2 online softmax with running max/sum.
* - P is UN-NORMALIZED (exp(s - max), not divided by sum)
* - After each KV tile: read PV result from TMEM → add to sOacc
* - Rescale sOacc when running_max changes: sOacc *= exp(old_max - new_max)
* - Final: O_normalized = sOacc / running_sum
* n_kv_tiles <= 1 (single KV segment, s_k <= 128):
* Direct QK → softmax → PV → epilogue. No rescale.
* Identical to the pre-P5 kernel.
*
* n_kv_tiles > 1 (multiple KV segments, s_k > 128):
* FlashAttention-2 online softmax across KV tiles.
* - P is UN-NORMALIZED: exp(s - max), NOT divided by sum
* - SMEM accumulator sOacc[HD]: O += PV after each tile
* - Rescale: sOacc *= exp(old_max - new_max) when max changes
* - Final: O = sOacc / running_sum
*
* ==================================================================
* MQA / GQA SUPPORT
* SMEM BUDGET (additional for multi-tile)
* ==================================================================
* Same as single-tile. K/V head strides handled by caller.
* For MQA: k_head_stride=0, v_head_stride=0.
* For GQA: K/V expanded via repeat_interleave in Python.
*
* ==================================================================
* SMEM LAYOUT (P5 additions in comments)
* ==================================================================
* sTmemBase: 4 bytes
* sRowMax: 4 bytes (running max across all KV tiles)
* sRowSum: 4 bytes (running sum across all KV tiles)
* padding: ~4 bytes (alignment)
* sQ0: 128 * hd * 2 (Q canonical, reused per KV tile)
* sK0: 128 * 16 * 2 (K segment canonical, one SK_TILE at a time)
* sPk: 128 * 16 * 2 (P sub-tile canonical, reused per PV)
* sV: 16 * 16 * 2 (V sub-tile canonical, reused per PV)
* s_p_vals: SK_TILE * 4 (un-normalized P values)
* sOacc: hd * 4 (float accumulator, 1 row for T=1)
* Total additional vs single-tile: ~hd*4 bytes (sOacc)
* sOacc: HD * 4 bytes (float accumulator, 1 row for T=1)
* hd=64: +256B. hd=128: +512B. hd=256: +1024B.
* All well within 232KB SMEM budget.
*
* ==================================================================
* OUTPUT: NORMALIZED O + LSE
* ==================================================================
* O: normalized attention output (divided by running_sum)
* LSE: ln(running_sum) + running_max — for external multi-segment merge if needed
* Total SMEM at hd=64: ~14 KB. hd=128: ~15 KB. hd=256: ~17 KB.
* Well within 232 KB SMEM budget.
*/
#pragma once
@@ -66,7 +44,7 @@ struct FmhaParams {
int s_k; // Total KV sequence length
float scale; // 1/sqrt(hd)
int head_dim; // hd
int n_kv_tiles; // Number of KV tiles (0 = auto-calc from s_k)
int n_kv_tiles; // Number of KV tiles (<= 0 = auto-calc from s_k, which may select multi-tile; 1 = single-tile; >1 = multi-tile)
int q_head_stride, q_batch_stride;
int k_head_stride, k_batch_stride;
@@ -95,7 +73,6 @@ fmha_6warp_multihead_kernel(FmhaParams params) {
const bool is_mma_warp = (wid == 4);
const bool is_load_warp = (wid == 5);
// Per-head GMEM pointers
const bf16_t* __restrict__ q_head = params.q
+ head_idx * params.q_head_stride + batch_idx * params.q_batch_stride;
const bf16_t* __restrict__ k_head = params.k
@@ -113,23 +90,25 @@ fmha_6warp_multihead_kernel(FmhaParams params) {
const float scale = params.scale;
const int n_kv_tiles = (params.n_kv_tiles > 0) ? params.n_kv_tiles
: (s_k + SK_TILE - 1) / SK_TILE;
const bool is_multi_tile = (n_kv_tiles > 1);
// ================================================================
// SMEM allocation
// ================================================================
extern __shared__ char sbuf[];
uint32_t* sTmemBase = (uint32_t*)sbuf;
float* sRowMax = (float*)(sbuf + 4); // running max across all tiles
float* sRowSum = sRowMax + 1; // running sum across all tiles
float* sRowMax = (float*)(sbuf + 4);
float* sRowSum = sRowMax + 1;
bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sRowSum + 1) + 15) & ~(uintptr_t)15);
bf16_t* sK0 = sQ0 + TILE_SZ;
bf16_t* sPk = (bf16_t*)(((uintptr_t)(sK0 + TILE_SZ) + 127) & ~(uintptr_t)127);
bf16_t* sV = (bf16_t*)(((uintptr_t)(sPk + TILE_SZ) + 127) & ~(uintptr_t)127);
float* s_p_vals = (float*)(sV + V_SUB_SZ);
// Multi-tile accumulator (only used when n_kv_tiles > 1)
float* sOacc = (float*)(((uintptr_t)(s_p_vals + SK_TILE) + 15) & ~(uintptr_t)15);
// Initialize accumulator (single thread to avoid race)
if (tid == 0) {
// Initialize multi-tile accumulator
if (is_multi_tile && tid == 0) {
*sRowMax = -INFINITY;
*sRowSum = 0.0f;
for (int d = 0; d < HD; d++) sOacc[d] = 0.0f;
@@ -144,37 +123,28 @@ fmha_6warp_multihead_kernel(FmhaParams params) {
uint32_t tb = *sTmemBase;
// ================================================================
// MAIN LOOP: iterate over KV tiles
// SINGLE-TILE PATH (n_kv_tiles <= 1)
// Identical to the pre-P5 kernel. Tested, proven correct.
// ================================================================
for (int kv_tile = 0; kv_tile < n_kv_tiles; kv_tile++) {
int kv_start = kv_tile * SK_TILE;
int kv_len = min(SK_TILE, s_k - kv_start);
// ============================================================
// QK GEMM: for each K-tile (in hd), load Q + K segment, MMA
// ============================================================
if (!is_multi_tile) {
// QK GEMM
for (int kt = 0; kt < NKT_QK; kt++) {
if (is_load_warp) {
// Load Q (same for all KV tiles)
for (int i = lane; i < TILE_SZ; i += 32) sQ0[i] = 0;
for (int d = lane; d < MMA_K_BF16; d += 32) {
int ck = d / 8, lc = d % 8;
sQ0[ck * 16 * 64 + lc] = q_head[kt * MMA_K_BF16 + d];
}
// Load K segment
for (int i = lane; i < TILE_SZ; i += 32) sK0[i] = 0;
for (int r = 0; r < kv_len; r++) {
int g_r = kv_start + r;
for (int r = 0; r < s_k; r++) {
for (int d = lane; d < MMA_K_BF16; d += 32) {
int ck = d / 8, lc = d % 8;
int tmn = r / 8, lr = r % 8;
sK0[ck * 16 * 64 + tmn * 64 + lr * 8 + lc] =
k_head[g_r * HD + kt * MMA_K_BF16 + d];
sK0[ck * 16 * 64 + tmn * 64 + lr * 8 + lc] = k_head[r * HD + kt * MMA_K_BF16 + d];
}
}
}
__syncthreads();
if (is_mma_warp) {
uint32_t idesc = make_idesc(128, 128);
uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), 128);
@@ -185,10 +155,7 @@ fmha_6warp_multihead_kernel(FmhaParams params) {
__syncthreads();
}
// ============================================================
// Softmax (warp 0, row 0 for T=1 decode)
// P is UN-NORMALIZED for multi-tile: exp(s - max), NOT / sum
// ============================================================
// Softmax (normalized P for single-tile)
if (wid == 0) {
float s_vals[SK_TILE], row_max = -INFINITY;
for (int n = 0; n < SK_TILE / 8; n++) {
@@ -199,102 +166,62 @@ fmha_6warp_multihead_kernel(FmhaParams params) {
: "r"(tb + n*8));
asm volatile("tcgen05.wait::ld.sync.aligned;");
if (lane == 0) for (int c=0;c<8;c++) {
float val = tmp[c] * scale;
s_vals[n*8+c] = val;
row_max = fmaxf(row_max, val);
s_vals[n*8+c] = tmp[c] * scale;
row_max = fmaxf(row_max, tmp[c] * scale);
}
}
row_max = wmax(row_max);
if (lane == 0) *sRowMax = row_max;
float row_sum = 0.0f;
if (lane == 0) for (int j=0;j<SK_TILE;j++) {
s_vals[j] = expf(s_vals[j] - row_max);
row_sum += s_vals[j];
}
row_sum = wsum(row_sum);
// Online softmax rescale: sOacc *= exp(old_max - new_max)
// This is the FlashAttention-2 running max/sum update.
float old_max, new_max, old_sum;
if (lane == 0) {
old_max = *sRowMax;
new_max = row_max;
old_sum = *sRowSum;
if (old_max > -INFINITY) {
float rescale = expf(old_max - new_max);
for (int d = 0; d < HD; d++) sOacc[d] *= rescale;
old_sum *= rescale;
}
*sRowMax = new_max;
}
// Store P for PV. For single KV tile, use NORMALIZED P (same as old kernel).
// For multi-tile, use UN-NORMALIZED P (critical for FA2 rescale).
// Single-tile: O = P_norm × V. Multi-tile: O = Σ(P_unnorm × V) / running_sum
if (lane == 0) {
if (n_kv_tiles == 1) {
// Normalized P (backward compatible with single-tile)
for (int j=0;j<SK_TILE;j++) s_vals[j] /= row_sum;
}
for (int j=0;j<SK_TILE;j++) s_p_vals[j] = s_vals[j];
}
// Update running sum
if (lane == 0) *sRowSum = old_sum + row_sum;
if (lane == 0) *sRowSum = row_sum;
if (lane == 0) for (int j=0;j<SK_TILE;j++) s_vals[j] /= row_sum;
if (lane == 0) for (int j=0;j<SK_TILE;j++) s_p_vals[j] = s_vals[j];
}
__syncthreads();
// ============================================================
// PV GEMM: P_unnorm × V → O_tile, accumulate into TMEM
// For each KV tile: ACCUMULATE=False on first sub-tile, True on rest
// ============================================================
// PV GEMM
for (int n = 0; n < N_NSUB; n++) {
int d_base = n * 16;
for (int kt = 0; kt < NKT_PV; kt++) {
if (is_load_warp) {
// Fill sPk from s_p_vals (un-normalized)
for (int i = lane; i < TILE_SZ; i += 32) sPk[i] = 0;
if (lane < 16) {
int c = lane;
int ck = c / 8, lc = c % 8;
sPk[ck * 16 * 64 + 0 * 64 + 0 * 8 + lc] =
f32_to_bf16(s_p_vals[kt * MMA_K_BF16 + c]);
sPk[ck * 16 * 64 + 0 * 64 + 0 * 8 + lc] = f32_to_bf16(s_p_vals[kt * MMA_K_BF16 + c]);
}
// Load V sub-tile
for (int i = lane; i < V_SUB_SZ; i += 32) sV[i] = 0;
for (int dd = lane; dd < 16; dd += 32) {
for (int lr = 0; lr < MMA_K_BF16; lr++) {
int r = kv_start + kt * MMA_K_BF16 + lr;
int r = kt * MMA_K_BF16 + lr;
int g_mn = dd / 8, g_k = lr / 8;
int llr = dd % 8, lc = lr % 8;
sV[g_k * 2 * 64 + g_mn * 64 + llr * 8 + lc] =
v_head[(d_base + dd) * s_k + r];
sV[g_k * 2 * 64 + g_mn * 64 + llr * 8 + lc] = v_head[(d_base + dd) * s_k + r];
}
}
}
__syncthreads();
if (is_mma_warp) {
uint32_t idesc_pv16 = make_idesc(128, 16);
uint64_t dp = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sPk), 128);
uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
// ACCUMULATE=False for first PV sub-tile of each KV tile
// (clears TMEM for that sub-tile). True for subsequent sub-tiles.
bool acc = !(n == 0 && kt == 0);
if (tid == 128) umma_ss_f16(tb + n * 16, dp, dv, idesc_pv16, acc);
if (tid == 128) umma_ss_f16(tb + n * 16, dp, dv, idesc_pv16, kt > 0);
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
}
__syncthreads();
}
}
// ============================================================
// Read PV result from TMEM → add to sOacc
// After reading, TMEM can be reused for the next KV tile.
// ============================================================
// Epilogue: TMEM → regs → BF16 → GMEM (P was normalized, no division needed)
if (wid == 0) {
float row_max = *sRowMax;
float row_sum = *sRowSum;
float o_vals[HD];
for (int n = 0; n < HD / 8; n++) {
float tmp[8];
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
@@ -302,45 +229,160 @@ fmha_6warp_multihead_kernel(FmhaParams params) {
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
: "r"(tb + n*8));
asm volatile("tcgen05.wait::ld.sync.aligned;");
if (lane == 0) for (int c=0;c<8;c++) sOacc[n*8+c] += tmp[c];
if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c];
}
if (lane == 0) {
for (int d = 0; d < HD; d++) o_head[d] = f32_to_bf16(o_vals[d]);
if (lse_head) lse_head[0] = logf(row_sum) + row_max;
}
}
__syncthreads();
} // end KV tile loop
// ================================================================
// Epilogue: write O to GMEM
// For single KV tile (normalized P): sOacc = P_norm × V (already normalized)
// For multi KV tile (un-normalized P): O = sOacc / running_sum
// MULTI-TILE PATH (n_kv_tiles > 1)
// FlashAttention-2 online softmax across KV tiles.
// ================================================================
if (wid == 0) {
float running_max = *sRowMax;
float running_sum = *sRowSum;
} else {
for (int kv_tile = 0; kv_tile < n_kv_tiles; kv_tile++) {
int kv_start = kv_tile * SK_TILE;
int kv_len = min(SK_TILE, s_k - kv_start);
if (lane == 0) {
if (n_kv_tiles > 1) {
float inv_sum = 1.0f / running_sum;
for (int d = 0; d < HD; d++) {
o_head[d] = f32_to_bf16(sOacc[d] * inv_sum);
// QK GEMM (same as single-tile, but K offset = kv_start)
for (int kt = 0; kt < NKT_QK; kt++) {
if (is_load_warp) {
for (int i = lane; i < TILE_SZ; i += 32) sQ0[i] = 0;
for (int d = lane; d < MMA_K_BF16; d += 32) {
int ck = d / 8, lc = d % 8;
sQ0[ck * 16 * 64 + lc] = q_head[kt * MMA_K_BF16 + d];
}
for (int i = lane; i < TILE_SZ; i += 32) sK0[i] = 0;
for (int r = 0; r < kv_len; r++) {
int g_r = kv_start + r;
for (int d = lane; d < MMA_K_BF16; d += 32) {
int ck = d / 8, lc = d % 8;
int tmn = r / 8, lr = r % 8;
sK0[ck * 16 * 64 + tmn * 64 + lr * 8 + lc] = k_head[g_r * HD + kt * MMA_K_BF16 + d];
}
}
}
} else {
// Single tile: P was normalized, sOacc is already the correct output
for (int d = 0; d < HD; d++) {
o_head[d] = f32_to_bf16(sOacc[d]);
__syncthreads();
if (is_mma_warp) {
uint32_t idesc = make_idesc(128, 128);
uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), 128);
uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK0), 128);
if (tid == 128) umma_ss_f16(tb, dq, dk, idesc, kt > 0);
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
}
__syncthreads();
}
// Softmax: UN-NORMALIZED P (exp(s - max), NOT / sum)
if (wid == 0) {
float s_vals[SK_TILE], row_max = -INFINITY;
for (int n = 0; n < SK_TILE / 8; n++) {
float tmp[8];
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
: "r"(tb + n*8));
asm volatile("tcgen05.wait::ld.sync.aligned;");
if (lane == 0) for (int c=0;c<8;c++) {
s_vals[n*8+c] = tmp[c] * scale;
row_max = fmaxf(row_max, tmp[c] * scale);
}
}
row_max = wmax(row_max);
float row_sum = 0.0f;
if (lane == 0) for (int j=0;j<SK_TILE;j++) {
s_vals[j] = expf(s_vals[j] - row_max);
row_sum += s_vals[j];
}
row_sum = wsum(row_sum);
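// Online softmax rescale: sOacc *= exp(old_max - row_max), and old_sum likewise.
// NOTE(review): the new reference is this tile's row_max, not max(old_max, row_max),
// so the factor exceeds 1 whenever this tile's max is smaller than the stored one.
// Confirm that expf() overflow is not a concern for the expected score range.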
if (lane == 0) {
float old_max = *sRowMax;
float old_sum = *sRowSum;
if (old_max > -INFINITY) {
float rescale = expf(old_max - row_max);
for (int d = 0; d < HD; d++) sOacc[d] *= rescale;
old_sum *= rescale;
}
*sRowMax = row_max;
*sRowSum = old_sum + row_sum;
}
// Store UN-NORMALIZED P for PV
if (lane == 0) for (int j=0;j<SK_TILE;j++) s_p_vals[j] = s_vals[j];
}
__syncthreads();
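// PV GEMM (ACCUMULATE=False only for n == 0, kt == 0 of each KV tile; True otherwise)
// NOTE(review): the single-tile path passes (kt > 0), i.e. it clears at kt == 0 for
// every n. Here sub-tiles n > 0 accumulate onto whatever TMEM already holds at
// tb + n * 16 — confirm this difference is intended.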
for (int n = 0; n < N_NSUB; n++) {
int d_base = n * 16;
for (int kt = 0; kt < NKT_PV; kt++) {
if (is_load_warp) {
for (int i = lane; i < TILE_SZ; i += 32) sPk[i] = 0;
if (lane < 16) {
int c = lane;
int ck = c / 8, lc = c % 8;
sPk[ck * 16 * 64 + 0 * 64 + 0 * 8 + lc] = f32_to_bf16(s_p_vals[kt * MMA_K_BF16 + c]);
}
for (int i = lane; i < V_SUB_SZ; i += 32) sV[i] = 0;
for (int dd = lane; dd < 16; dd += 32) {
for (int lr = 0; lr < MMA_K_BF16; lr++) {
int r = kv_start + kt * MMA_K_BF16 + lr;
int g_mn = dd / 8, g_k = lr / 8;
int llr = dd % 8, lc = lr % 8;
sV[g_k * 2 * 64 + g_mn * 64 + llr * 8 + lc] = v_head[(d_base + dd) * s_k + r];
}
}
}
__syncthreads();
if (is_mma_warp) {
uint32_t idesc_pv16 = make_idesc(128, 16);
uint64_t dp = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sPk), 128);
uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
// ACCUMULATE=False for first PV sub-tile of each KV tile
bool acc = !(n == 0 && kt == 0);
if (tid == 128) umma_ss_f16(tb + n * 16, dp, dv, idesc_pv16, acc);
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
}
__syncthreads();
}
}
if (lse_head) {
lse_head[0] = logf(running_sum) + running_max;
// Read PV from TMEM → accumulate into sOacc
if (wid == 0) {
for (int n = 0; n < HD / 8; n++) {
float tmp[8];
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
: "r"(tb + n*8));
asm volatile("tcgen05.wait::ld.sync.aligned;");
if (lane == 0) for (int c=0;c<8;c++) sOacc[n*8+c] += tmp[c];
}
}
__syncthreads();
} // end KV tile loop
// Epilogue: normalize sOacc by running_sum
if (wid == 0) {
float running_max = *sRowMax;
float running_sum = *sRowSum;
if (lane == 0) {
float inv_sum = 1.0f / running_sum;
for (int d = 0; d < HD; d++) o_head[d] = f32_to_bf16(sOacc[d] * inv_sum);
if (lse_head) lse_head[0] = logf(running_sum) + running_max;
}
}
__syncthreads();
}
__syncthreads();
// TMEM dealloc
if (is_mma_warp) {
tmem_dealloc(tb, TMEM_N);
}
if (is_mma_warp) tmem_dealloc(tb, TMEM_N);
}
} // namespace dsv4::kernels::attention