feat: double-buffer TMA pipeline in multi-row kernel
This commit is contained in:
@@ -1,27 +1,9 @@
|
||||
/**
|
||||
* DSV4 FMHA — 6-warp TMA kernel, multi-row softmax (prefill T>1).
|
||||
* DSV4 FMHA — 6-warp TMA kernel, multi-row softmax, double-buffer pipeline.
|
||||
*
|
||||
* Based on fmha_6warp_multirow.cuh with TMA async loads for K.
|
||||
* Q and V remain direct GMEM loads for now.
|
||||
*
|
||||
* ==================================================================
|
||||
* DESIGN
|
||||
* ==================================================================
|
||||
*
|
||||
* 6-warp CTA: warps 0-3 = softmax, warp 4 = MMA, warp 5 = TMA load.
|
||||
* Grid: (1, n_h, batch) — each CTA processes one head of one batch item.
|
||||
*
|
||||
* TMA pipeline:
|
||||
* - K: TMA async load via cp.async.bulk.tensor.2d with mbarrier
|
||||
* - Q: direct GMEM load (multi-row, but small enough for warp-stride)
|
||||
* - V: direct GMEM load (16×16 sub-tiles)
|
||||
* - sTmaBuf: staging area for TMA→canonical conversion
|
||||
*
|
||||
* Flow:
|
||||
* 1. QK GEMM: Q direct + K TMA → S in TMEM
|
||||
* 2. Softmax: 2-pass (row_max, exp+sum+P), P in registers
|
||||
* 3. PV GEMM: P→sPk + V direct → O in TMEM
|
||||
* 4. Epilogue: O from TMEM → normalize → BF16 → GMEM + LSE
|
||||
* Double-buffer TMA for K loads: overlap TMA DMA of sub-tile N+1 with
|
||||
* MMA compute of sub-tile N. V is loaded via TMA (single-buffer, tiny tiles).
|
||||
* Q loaded directly from GMEM.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
@@ -34,13 +16,12 @@ namespace dsv4::kernels::attention {
|
||||
|
||||
struct FmhaTmaMultiRowParams {
|
||||
const bf16_t* __restrict__ q;
|
||||
CUtensorMap* __restrict__ tma_k; // Array of [n_h] TMA descriptors for K
|
||||
CUtensorMap* __restrict__ tma_v; // Array of [n_h] TMA descriptors for V
|
||||
const bf16_t* __restrict__ v; // V: direct fallback (HD, s_k)
|
||||
CUtensorMap* __restrict__ tma_k;
|
||||
CUtensorMap* __restrict__ tma_v;
|
||||
const bf16_t* __restrict__ v;
|
||||
bf16_t* __restrict__ o;
|
||||
float* __restrict__ lse;
|
||||
int s_k, T, n_h;
|
||||
int n_kv_tiles; // number of KV tiles (s_k / SK_TILE)
|
||||
float scale;
|
||||
int head_dim;
|
||||
int q_head_stride, q_batch_stride;
|
||||
@@ -77,13 +58,15 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) {
|
||||
const float scale = params.scale;
|
||||
|
||||
const bf16_t* __restrict__ q_head = params.q + head_idx * params.q_head_stride + batch_idx * params.q_batch_stride;
|
||||
// v_head kept for reference; V loaded via TMA
|
||||
// const bf16_t* __restrict__ v_head = params.v + head_idx * params.v_head_stride + batch_idx * params.v_batch_stride;
|
||||
// v_head not used — V loaded via TMA
|
||||
bf16_t* __restrict__ o_head = params.o + head_idx * params.o_head_stride + batch_idx * params.o_batch_stride;
|
||||
float* __restrict__ lse_head = params.lse ? params.lse + head_idx * params.lse_head_stride + batch_idx * params.lse_batch_stride : nullptr;
|
||||
|
||||
CUtensorMap* __restrict__ my_tma_k = params.tma_k + batch_idx * params.n_h + head_idx;
|
||||
CUtensorMap* __restrict__ my_tma_v = params.tma_v + batch_idx * params.n_h + head_idx;
|
||||
|
||||
// ================================================================
|
||||
// SMEM allocation — 128-byte aligned for TMA
|
||||
// SMEM allocation — 128-byte aligned
|
||||
// ================================================================
|
||||
extern __shared__ __align__(128) char sbuf[];
|
||||
size_t off = 0;
|
||||
@@ -97,13 +80,15 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) {
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
bf16_t* sK0 = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t);
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
bf16_t* sK1 = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t);
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
bf16_t* sPk = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t);
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
bf16_t* sV = (bf16_t*)(sbuf + off); off += V_SUB_SZ * sizeof(bf16_t);
|
||||
float* sRowMax = (float*)(sbuf + off); off += MAX_ROWS * sizeof(float);
|
||||
float* sRowSum = (float*)(sbuf + off); off += MAX_ROWS * sizeof(float);
|
||||
|
||||
// TMEM alloc + mbarrier init
|
||||
// Init: TMEM allocation (MMA warp) and mbarrier setup (thread 0)
|
||||
if (is_mma_warp) tmem_alloc(__cvta_generic_to_shared(sTmemBase), TMEM_N);
|
||||
if (tid == 0) {
|
||||
tma_mbarrier_init((uint32_t)__cvta_generic_to_shared(sMbar), 1);
|
||||
@@ -114,22 +99,38 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) {
|
||||
const uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar);
|
||||
int phase = 0;
|
||||
|
||||
CUtensorMap* __restrict__ my_tma_k = params.tma_k + batch_idx * params.n_h + head_idx;
|
||||
CUtensorMap* __restrict__ my_tma_v = params.tma_v + batch_idx * params.n_h + head_idx;
|
||||
const bool my_warp_active = (T <= 32) ? (wid == 0) : is_softmax_warp;
|
||||
const int my_row = my_warp_active ? (wid * 32 + lane) : 0;
|
||||
const bool my_row_active = my_warp_active && (my_row < T);
|
||||
|
||||
// ================================================================
|
||||
// QK GEMM → S in TMEM
|
||||
// QK GEMM — double-buffer TMA pipeline
|
||||
// ================================================================
|
||||
for (int kt = 0; kt < NKT_QK; kt++) {
|
||||
// Load Q: direct from GMEM, all T rows
|
||||
{
|
||||
bf16_t* sK_bufs[2] = {sK0, sK1};
|
||||
int cur_buf = 0;
|
||||
|
||||
// Preload kt=0
|
||||
if (is_load_warp && lane == 0) {
|
||||
tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)my_tma_k, mbar_addr, 0, 0);
|
||||
tma_mbarrier_arrive_expect_tx(mbar_addr, TMA_TILE_BYTES);
|
||||
}
|
||||
tma_mbarrier_wait(mbar_addr, phase); phase ^= 1;
|
||||
__syncthreads();
|
||||
|
||||
for (int i = tid; i < TILE_SZ; i += 192) sK0[i] = 0;
|
||||
for (int i = tid; i < s_k * MMA_K_BF16; i += 192) {
|
||||
int r = i / MMA_K_BF16, c = i % MMA_K_BF16;
|
||||
int ck = c/8, lc = c%8, tmn = r/8, lr = r%8;
|
||||
sK0[ck*CORES_MN*64 + tmn*64 + lr*8 + lc] = sTmaBuf[i];
|
||||
}
|
||||
|
||||
// Load Q for kt=0
|
||||
if (is_load_warp) {
|
||||
for (int i = lane; i < TILE_SZ; i += 32) sQ0[i] = 0;
|
||||
for (int r = 0; r < T; r++) {
|
||||
for (int d = lane; d < MMA_K_BF16; d += 32) {
|
||||
int full_d = kt * MMA_K_BF16 + d;
|
||||
int full_d = d;
|
||||
if (full_d < HD) {
|
||||
int ck = d/8, lc = d%8, cm = r/8, lr = r%8;
|
||||
sQ0[ck*CORES_MN*64 + cm*64 + lr*8 + lc] = q_head[r * HD + full_d];
|
||||
@@ -137,34 +138,63 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Load K: TMA async
|
||||
if (is_load_warp && lane == 0) {
|
||||
tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)my_tma_k,
|
||||
mbar_addr, kt * MMA_K_BF16, 0);
|
||||
tma_mbarrier_arrive_expect_tx(mbar_addr, TMA_TILE_BYTES);
|
||||
}
|
||||
tma_mbarrier_wait(mbar_addr, phase); phase ^= 1;
|
||||
__syncthreads();
|
||||
|
||||
// Convert TMA row-major → canonical
|
||||
for (int i = tid; i < TILE_SZ; i += 192) sK0[i] = 0;
|
||||
for (int i = tid; i < s_k * MMA_K_BF16; i += 192) {
|
||||
int r = i / MMA_K_BF16, c = i % MMA_K_BF16;
|
||||
int ck = c/8, lc = c%8, tmn = r/8, lr = r%8;
|
||||
sK0[ck*CORES_MN*64 + tmn*64 + lr*8 + lc] = sTmaBuf[i];
|
||||
}
|
||||
__syncthreads();
|
||||
// Pipeline loop
|
||||
for (int kt = 0; kt < NKT_QK; kt++) {
|
||||
bf16_t* cur_sK = sK_bufs[cur_buf];
|
||||
bool has_next = (kt + 1 < NKT_QK);
|
||||
|
||||
// MMA
|
||||
if (is_mma_warp) {
|
||||
uint32_t idesc = make_idesc(128, 128);
|
||||
uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), 128);
|
||||
uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK0), 128);
|
||||
if (tid == 128) umma_ss_f16(tb, dq, dk, idesc, kt > 0);
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
// 1. Issue TMA for kt+1 (DMA runs in parallel with MMA)
|
||||
if (has_next && is_load_warp && lane == 0) {
|
||||
tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)my_tma_k,
|
||||
mbar_addr, (kt+1) * MMA_K_BF16, 0);
|
||||
tma_mbarrier_arrive_expect_tx(mbar_addr, TMA_TILE_BYTES);
|
||||
}
|
||||
|
||||
// 2. MMA QK
|
||||
if (is_mma_warp) {
|
||||
uint32_t idesc = make_idesc(128, 128);
|
||||
uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), 128);
|
||||
uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(cur_sK), 128);
|
||||
if (tid == 128) umma_ss_f16(tb, dq, dk, idesc, kt > 0);
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
}
|
||||
|
||||
// 3. Wait for TMA kt+1
|
||||
if (has_next) {
|
||||
tma_mbarrier_wait(mbar_addr, phase); phase ^= 1;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// 4. Convert sTmaBuf (TMA row-major) → canonical layout in the next K buffer
|
||||
if (has_next) {
|
||||
bf16_t* next_sK = sK_bufs[1 - cur_buf];
|
||||
for (int i = tid; i < TILE_SZ; i += 192) next_sK[i] = 0;
|
||||
for (int i = tid; i < s_k * MMA_K_BF16; i += 192) {
|
||||
int r = i / MMA_K_BF16, c = i % MMA_K_BF16;
|
||||
int ck = c/8, lc = c%8, tmn = r/8, lr = r%8;
|
||||
next_sK[ck*CORES_MN*64 + tmn*64 + lr*8 + lc] = sTmaBuf[i];
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Load Q for kt+1
|
||||
if (has_next && is_load_warp) {
|
||||
for (int i = lane; i < TILE_SZ; i += 32) sQ0[i] = 0;
|
||||
for (int r = 0; r < T; r++) {
|
||||
for (int d = lane; d < MMA_K_BF16; d += 32) {
|
||||
int full_d = (kt+1) * MMA_K_BF16 + d;
|
||||
if (full_d < HD) {
|
||||
int ck = d/8, lc = d%8, cm = r/8, lr = r%8;
|
||||
sQ0[ck*CORES_MN*64 + cm*64 + lr*8 + lc] = q_head[r * HD + full_d];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cur_buf = 1 - cur_buf;
|
||||
__syncthreads();
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
asm volatile("fence.sc.gpu;" ::: "memory");
|
||||
@@ -173,7 +203,6 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) {
|
||||
// ================================================================
|
||||
// SOFTMAX — 2-pass, P in registers
|
||||
// ================================================================
|
||||
// Pass 1: row_max
|
||||
float my_row_max = -INFINITY;
|
||||
if (my_warp_active) {
|
||||
for (int n = 0; n < NUM_READS; n++) {
|
||||
@@ -194,7 +223,6 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) {
|
||||
if (my_row_active) sRowMax[my_row] = my_row_max;
|
||||
__syncthreads();
|
||||
|
||||
// Pass 2: exp + sum + P
|
||||
float my_p_vals[SK_TILE];
|
||||
float my_row_sum = 0.0f;
|
||||
if (my_warp_active) {
|
||||
@@ -222,20 +250,18 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) {
|
||||
__syncthreads();
|
||||
|
||||
// ================================================================
|
||||
// PV GEMM — P→sPk + V direct → O in TMEM
|
||||
// PV GEMM — P→sPk + V TMA → O in TMEM
|
||||
// ================================================================
|
||||
for (int n_sub = 0; n_sub < N_NSUB; n_sub++) {
|
||||
int d_base = n_sub * 16;
|
||||
for (int pv_kt = 0; pv_kt < NKT_PV; pv_kt++) {
|
||||
const int col_start = pv_kt * MMA_K_BF16;
|
||||
|
||||
// Zero sPk
|
||||
if (is_load_warp) {
|
||||
for (int i = lane; i < TILE_SZ; i += 32) sPk[i] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Softmax warps: write P to sPk (all active rows)
|
||||
if (my_row_active) {
|
||||
for (int c = 0; c < MMA_K_BF16; c++) {
|
||||
int gc = col_start + c;
|
||||
@@ -246,7 +272,7 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) {
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Load V sub-tile via TMA
|
||||
// Load V sub-tile via TMA into sTmaBuf
|
||||
if (is_load_warp && lane == 0) {
|
||||
tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)my_tma_v,
|
||||
mbar_addr, col_start, d_base);
|
||||
@@ -255,7 +281,6 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) {
|
||||
tma_mbarrier_wait(mbar_addr, phase); phase ^= 1;
|
||||
__syncthreads();
|
||||
|
||||
// Convert sTmaBuf → canonical sV
|
||||
for (int i = tid; i < V_SUB_SZ; i += 192) sV[i] = 0;
|
||||
for (int i = tid; i < 16 * MMA_K_BF16; i += 192) {
|
||||
int dd = i / MMA_K_BF16, lr = i % MMA_K_BF16;
|
||||
@@ -264,7 +289,6 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) {
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// MMA
|
||||
if (is_mma_warp) {
|
||||
uint32_t idesc_pv = make_idesc(128, 16);
|
||||
uint64_t dp = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sPk), 128);
|
||||
@@ -280,8 +304,7 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) {
|
||||
__syncthreads();
|
||||
|
||||
// ================================================================
|
||||
// EPILOGUE — O from TMEM → normalize → GMEM + LSE
|
||||
// TMEM loads are warp-collective: MUST be outside my_row_active guard
|
||||
// EPILOGUE
|
||||
// ================================================================
|
||||
if (my_warp_active) {
|
||||
float rm = my_row_active ? sRowMax[my_row] : 0.0f;
|
||||
|
||||
@@ -34,13 +34,14 @@ static size_t compute_smem() {
|
||||
size_t off = 0;
|
||||
off += 4; off = (off+127)&~(size_t)127;
|
||||
off += 16; off = (off+127)&~(size_t)127;
|
||||
off += TILE_SZ * 2; off = (off+127)&~(size_t)127;
|
||||
off += TILE_SZ * 2; off = (off+127)&~(size_t)127;
|
||||
off += TILE_SZ * 2; off = (off+127)&~(size_t)127;
|
||||
off += TILE_SZ * 2; off = (off+127)&~(size_t)127;
|
||||
off += 16 * MY_MMA_K * 2;
|
||||
off += 128 * 4;
|
||||
off += 128 * 4;
|
||||
off += TILE_SZ * 2; off = (off+127)&~(size_t)127; // sTmaBuf
|
||||
off += TILE_SZ * 2; off = (off+127)&~(size_t)127; // sQ0
|
||||
off += TILE_SZ * 2; off = (off+127)&~(size_t)127; // sK0
|
||||
off += TILE_SZ * 2; off = (off+127)&~(size_t)127; // sK1 (double buffer)
|
||||
off += TILE_SZ * 2; off = (off+127)&~(size_t)127; // sPk
|
||||
off += 16 * MY_MMA_K * 2; // sV
|
||||
off += 128 * 4; // sRowMax
|
||||
off += 128 * 4; // sRowSum
|
||||
return off;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user