feat: double-buffer TMA pipeline in multi-row kernel

This commit is contained in:
2026-05-30 03:20:49 +00:00
parent 4a9c850e9c
commit 762f054d6d
2 changed files with 101 additions and 77 deletions

View File

@@ -1,27 +1,9 @@
/**
* DSV4 FMHA — 6-warp TMA kernel, multi-row softmax (prefill T>1).
* DSV4 FMHA — 6-warp TMA kernel, multi-row softmax, double-buffer pipeline.
*
* Based on fmha_6warp_multirow.cuh with TMA async loads for K.
* Q and V remain direct GMEM loads for now.
*
* ==================================================================
* DESIGN
* ==================================================================
*
* 6-warp CTA: warps 0-3 = softmax, warp 4 = MMA, warp 5 = TMA load.
* Grid: (1, n_h, batch) — each CTA processes one head of one batch item.
*
* TMA pipeline:
* - K: TMA async load via cp.async.bulk.tensor.2d with mbarrier
* - Q: direct GMEM load (multi-row, but small enough for warp-stride)
* - V: direct GMEM load (16×16 sub-tiles)
* - sTmaBuf: staging area for TMA→canonical conversion
*
* Flow:
* 1. QK GEMM: Q direct + K TMA → S in TMEM
* 2. Softmax: 2-pass (row_max, exp+sum+P), P in registers
* 3. PV GEMM: P→sPk + V direct → O in TMEM
* 4. Epilogue: O from TMEM → normalize → BF16 → GMEM + LSE
* Double-buffer TMA for K loads: overlap TMA DMA of sub-tile N+1 with
* MMA compute of sub-tile N. V loaded via TMA (single-buffer, tiny tiles).
* Q loaded directly from GMEM.
*/
#pragma once
@@ -34,13 +16,12 @@ namespace dsv4::kernels::attention {
struct FmhaTmaMultiRowParams {
const bf16_t* __restrict__ q;
CUtensorMap* __restrict__ tma_k; // Array of [n_h] TMA descriptors for K
CUtensorMap* __restrict__ tma_v; // Array of [n_h] TMA descriptors for V
const bf16_t* __restrict__ v; // V: direct fallback (HD, s_k)
CUtensorMap* __restrict__ tma_k;
CUtensorMap* __restrict__ tma_v;
const bf16_t* __restrict__ v;
bf16_t* __restrict__ o;
float* __restrict__ lse;
int s_k, T, n_h;
int n_kv_tiles; // number of KV tiles (s_k / SK_TILE)
float scale;
int head_dim;
int q_head_stride, q_batch_stride;
@@ -77,13 +58,15 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) {
const float scale = params.scale;
const bf16_t* __restrict__ q_head = params.q + head_idx * params.q_head_stride + batch_idx * params.q_batch_stride;
// v_head kept for reference; V loaded via TMA
// const bf16_t* __restrict__ v_head = params.v + head_idx * params.v_head_stride + batch_idx * params.v_batch_stride;
// v_head not used — V loaded via TMA
bf16_t* __restrict__ o_head = params.o + head_idx * params.o_head_stride + batch_idx * params.o_batch_stride;
float* __restrict__ lse_head = params.lse ? params.lse + head_idx * params.lse_head_stride + batch_idx * params.lse_batch_stride : nullptr;
CUtensorMap* __restrict__ my_tma_k = params.tma_k + batch_idx * params.n_h + head_idx;
CUtensorMap* __restrict__ my_tma_v = params.tma_v + batch_idx * params.n_h + head_idx;
// ================================================================
// SMEM allocation — 128-byte aligned for TMA
// SMEM allocation — 128-byte aligned
// ================================================================
extern __shared__ __align__(128) char sbuf[];
size_t off = 0;
@@ -97,13 +80,15 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) {
off = (off + 127) & ~(size_t)127;
bf16_t* sK0 = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t);
off = (off + 127) & ~(size_t)127;
bf16_t* sK1 = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t);
off = (off + 127) & ~(size_t)127;
bf16_t* sPk = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t);
off = (off + 127) & ~(size_t)127;
bf16_t* sV = (bf16_t*)(sbuf + off); off += V_SUB_SZ * sizeof(bf16_t);
float* sRowMax = (float*)(sbuf + off); off += MAX_ROWS * sizeof(float);
float* sRowSum = (float*)(sbuf + off); off += MAX_ROWS * sizeof(float);
// TMEM alloc + mbarrier init
// Init
if (is_mma_warp) tmem_alloc(__cvta_generic_to_shared(sTmemBase), TMEM_N);
if (tid == 0) {
tma_mbarrier_init((uint32_t)__cvta_generic_to_shared(sMbar), 1);
@@ -114,22 +99,38 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) {
const uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar);
int phase = 0;
CUtensorMap* __restrict__ my_tma_k = params.tma_k + batch_idx * params.n_h + head_idx;
CUtensorMap* __restrict__ my_tma_v = params.tma_v + batch_idx * params.n_h + head_idx;
const bool my_warp_active = (T <= 32) ? (wid == 0) : is_softmax_warp;
const int my_row = my_warp_active ? (wid * 32 + lane) : 0;
const bool my_row_active = my_warp_active && (my_row < T);
// ================================================================
// QK GEMM → S in TMEM
// QK GEMM — double-buffer TMA pipeline
// ================================================================
for (int kt = 0; kt < NKT_QK; kt++) {
// Load Q: direct from GMEM, all T rows
{
bf16_t* sK_bufs[2] = {sK0, sK1};
int cur_buf = 0;
// Preload kt=0
if (is_load_warp && lane == 0) {
tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)my_tma_k, mbar_addr, 0, 0);
tma_mbarrier_arrive_expect_tx(mbar_addr, TMA_TILE_BYTES);
}
tma_mbarrier_wait(mbar_addr, phase); phase ^= 1;
__syncthreads();
for (int i = tid; i < TILE_SZ; i += 192) sK0[i] = 0;
for (int i = tid; i < s_k * MMA_K_BF16; i += 192) {
int r = i / MMA_K_BF16, c = i % MMA_K_BF16;
int ck = c/8, lc = c%8, tmn = r/8, lr = r%8;
sK0[ck*CORES_MN*64 + tmn*64 + lr*8 + lc] = sTmaBuf[i];
}
// Load Q for kt=0
if (is_load_warp) {
for (int i = lane; i < TILE_SZ; i += 32) sQ0[i] = 0;
for (int r = 0; r < T; r++) {
for (int d = lane; d < MMA_K_BF16; d += 32) {
int full_d = kt * MMA_K_BF16 + d;
int full_d = d;
if (full_d < HD) {
int ck = d/8, lc = d%8, cm = r/8, lr = r%8;
sQ0[ck*CORES_MN*64 + cm*64 + lr*8 + lc] = q_head[r * HD + full_d];
@@ -137,34 +138,63 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) {
}
}
}
// Load K: TMA async
if (is_load_warp && lane == 0) {
tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)my_tma_k,
mbar_addr, kt * MMA_K_BF16, 0);
tma_mbarrier_arrive_expect_tx(mbar_addr, TMA_TILE_BYTES);
}
tma_mbarrier_wait(mbar_addr, phase); phase ^= 1;
__syncthreads();
// Convert TMA row-major → canonical
for (int i = tid; i < TILE_SZ; i += 192) sK0[i] = 0;
for (int i = tid; i < s_k * MMA_K_BF16; i += 192) {
int r = i / MMA_K_BF16, c = i % MMA_K_BF16;
int ck = c/8, lc = c%8, tmn = r/8, lr = r%8;
sK0[ck*CORES_MN*64 + tmn*64 + lr*8 + lc] = sTmaBuf[i];
}
__syncthreads();
// Pipeline loop
for (int kt = 0; kt < NKT_QK; kt++) {
bf16_t* cur_sK = sK_bufs[cur_buf];
bool has_next = (kt + 1 < NKT_QK);
// MMA
if (is_mma_warp) {
uint32_t idesc = make_idesc(128, 128);
uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), 128);
uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK0), 128);
if (tid == 128) umma_ss_f16(tb, dq, dk, idesc, kt > 0);
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
// 1. Issue TMA for kt+1 (DMA runs in parallel with MMA)
if (has_next && is_load_warp && lane == 0) {
tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)my_tma_k,
mbar_addr, (kt+1) * MMA_K_BF16, 0);
tma_mbarrier_arrive_expect_tx(mbar_addr, TMA_TILE_BYTES);
}
// 2. MMA QK
if (is_mma_warp) {
uint32_t idesc = make_idesc(128, 128);
uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), 128);
uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(cur_sK), 128);
if (tid == 128) umma_ss_f16(tb, dq, dk, idesc, kt > 0);
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
}
// 3. Wait for TMA kt+1
if (has_next) {
tma_mbarrier_wait(mbar_addr, phase); phase ^= 1;
}
__syncthreads();
// 4. Convert → next buffer
if (has_next) {
bf16_t* next_sK = sK_bufs[1 - cur_buf];
for (int i = tid; i < TILE_SZ; i += 192) next_sK[i] = 0;
for (int i = tid; i < s_k * MMA_K_BF16; i += 192) {
int r = i / MMA_K_BF16, c = i % MMA_K_BF16;
int ck = c/8, lc = c%8, tmn = r/8, lr = r%8;
next_sK[ck*CORES_MN*64 + tmn*64 + lr*8 + lc] = sTmaBuf[i];
}
}
// 5. Load Q for kt+1
if (has_next && is_load_warp) {
for (int i = lane; i < TILE_SZ; i += 32) sQ0[i] = 0;
for (int r = 0; r < T; r++) {
for (int d = lane; d < MMA_K_BF16; d += 32) {
int full_d = (kt+1) * MMA_K_BF16 + d;
if (full_d < HD) {
int ck = d/8, lc = d%8, cm = r/8, lr = r%8;
sQ0[ck*CORES_MN*64 + cm*64 + lr*8 + lc] = q_head[r * HD + full_d];
}
}
}
}
cur_buf = 1 - cur_buf;
__syncthreads();
}
__syncthreads();
}
asm volatile("fence.sc.gpu;" ::: "memory");
@@ -173,7 +203,6 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) {
// ================================================================
// SOFTMAX — 2-pass, P in registers
// ================================================================
// Pass 1: row_max
float my_row_max = -INFINITY;
if (my_warp_active) {
for (int n = 0; n < NUM_READS; n++) {
@@ -194,7 +223,6 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) {
if (my_row_active) sRowMax[my_row] = my_row_max;
__syncthreads();
// Pass 2: exp + sum + P
float my_p_vals[SK_TILE];
float my_row_sum = 0.0f;
if (my_warp_active) {
@@ -222,20 +250,18 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) {
__syncthreads();
// ================================================================
// PV GEMM — P→sPk + V direct → O in TMEM
// PV GEMM — P→sPk + V TMA → O in TMEM
// ================================================================
for (int n_sub = 0; n_sub < N_NSUB; n_sub++) {
int d_base = n_sub * 16;
for (int pv_kt = 0; pv_kt < NKT_PV; pv_kt++) {
const int col_start = pv_kt * MMA_K_BF16;
// Zero sPk
if (is_load_warp) {
for (int i = lane; i < TILE_SZ; i += 32) sPk[i] = 0;
}
__syncthreads();
// Softmax warps: write P to sPk (all active rows)
if (my_row_active) {
for (int c = 0; c < MMA_K_BF16; c++) {
int gc = col_start + c;
@@ -246,7 +272,7 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) {
}
__syncthreads();
// Load V sub-tile via TMA
// V via TMA
if (is_load_warp && lane == 0) {
tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)my_tma_v,
mbar_addr, col_start, d_base);
@@ -255,7 +281,6 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) {
tma_mbarrier_wait(mbar_addr, phase); phase ^= 1;
__syncthreads();
// Convert sTmaBuf → canonical sV
for (int i = tid; i < V_SUB_SZ; i += 192) sV[i] = 0;
for (int i = tid; i < 16 * MMA_K_BF16; i += 192) {
int dd = i / MMA_K_BF16, lr = i % MMA_K_BF16;
@@ -264,7 +289,6 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) {
}
__syncthreads();
// MMA
if (is_mma_warp) {
uint32_t idesc_pv = make_idesc(128, 16);
uint64_t dp = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sPk), 128);
@@ -280,8 +304,7 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) {
__syncthreads();
// ================================================================
// EPILOGUE — O from TMEM → normalize → GMEM + LSE
// TMEM loads are warp-collective: MUST be outside my_row_active guard
// EPILOGUE
// ================================================================
if (my_warp_active) {
float rm = my_row_active ? sRowMax[my_row] : 0.0f;

View File

@@ -34,13 +34,14 @@ static size_t compute_smem() {
size_t off = 0;
off += 4; off = (off+127)&~(size_t)127;
off += 16; off = (off+127)&~(size_t)127;
off += TILE_SZ * 2; off = (off+127)&~(size_t)127;
off += TILE_SZ * 2; off = (off+127)&~(size_t)127;
off += TILE_SZ * 2; off = (off+127)&~(size_t)127;
off += TILE_SZ * 2; off = (off+127)&~(size_t)127;
off += 16 * MY_MMA_K * 2;
off += 128 * 4;
off += 128 * 4;
off += TILE_SZ * 2; off = (off+127)&~(size_t)127; // sTmaBuf
off += TILE_SZ * 2; off = (off+127)&~(size_t)127; // sQ0
off += TILE_SZ * 2; off = (off+127)&~(size_t)127; // sK0
off += TILE_SZ * 2; off = (off+127)&~(size_t)127; // sK1 (double buffer)
off += TILE_SZ * 2; off = (off+127)&~(size_t)127; // sPk
off += 16 * MY_MMA_K * 2; // sV
off += 128 * 4; // sRowMax
off += 128 * 4; // sRowSum
return off;
}