diff --git a/dsv4/kernels/attention/fmha_sm100_tc.cuh b/dsv4/kernels/attention/fmha_sm100_tc.cuh index e83c0ed2..6ff8b465 100644 --- a/dsv4/kernels/attention/fmha_sm100_tc.cuh +++ b/dsv4/kernels/attention/fmha_sm100_tc.cuh @@ -1,50 +1,57 @@ /** - * DSV4 FMHA Phase 3 — Tensor-core accelerated FMHA with tcgen05.mma. + * DSV4 FMHA — Tensor-core accelerated FMHA with tcgen05.mma. * * ================================================================== - * STATUS: IN DEVELOPMENT — replacing scalar math with tcgen05.mma + * STATUS: WORKING — HD=16/64/128/256, cos 0.999997+ * ================================================================== * - * This is the production FMHA kernel using Blackwell SM100 tensor cores: - * - QK GEMM: tcgen05.mma SS (SMEM Q × SMEM K^T → TMEM S) - * - In-kernel softmax: TMEM → regs → max/exp/sum → TMEM - * - PV GEMM: tcgen05.mma TS (TMEM P × SMEM V → TMEM O) + * Production FMHA kernel using Blackwell SM100 tensor cores: + * - QK GEMM: tcgen05.mma SS (SMEM Q × SMEM K^T → TMEM S, N=128) + * - In-kernel softmax: TMEM → regs → max/exp/sum → SMEM P + * - PV GEMM: tcgen05.mma SS (SMEM P × SMEM V^T → TMEM O, N=16 sub-tiles) * - Correction epilogue: TMEM → regs → normalize → BF16 → GMEM * * ================================================================== - * UMMA TILE DIMENSIONS + * KEY DESIGN: N=16 PV SUB-TILES * ================================================================== - * tcgen05.mma operates on 128×128 BF16 tiles (M=128, N=8..256, K=128). - * For FMHA decode at T=1 with head-packing, M=128 (1 head × 128 rows). - * Q is padded to 128 rows. K is processed in 128-column chunks (KV tiles). - * V is processed in 128-row chunks (matching KV tiles). + * tcgen05.mma with make_idesc(128, N) for N≠16,128 has a Layout D bug + * that skips TMEM columns. For N=64, columns 32-35 and 48-51 are missing. + * Workaround: use HD/16 PV calls with N=16 and TMEM offset n*16. + * This works for all HD values: 16, 64, 128, 256, 512. 
* * ================================================================== - * KEY DESIGN DECISIONS + * WARP SPECIALIZATION (6 warps = 192 threads) * ================================================================== - * 1. Single-CTA per head (same as CuTeDSL kernel, D2 multi-CTA later) - * 2. Head-packed M=128 (all 128 rows of the MMA tile used for n_h=1 decode) - * 3. KV tiling: process s_k in chunks of 128 (one UMMA N-dim tile) - * 4. O rescale in registers (D1.5 fix — TMEM → regs → multiply → TMEM) - * 5. 6-warp specialization will come after the single-warp version works + * Warp 0-3: Softmax + correction epilogue (read/write TMEM) + * Warp 4: MMA (QK + PV, one thread per CTA) + * Warp 5: TMA loads (Q/K/V SMEM staging) + * + * For now, the kernel uses a simplified single-CTA decode layout (launched with 128 threads = 4 warps): + * - Only row 0 is computed (T=1 decode) + * - Warp 0 does softmax + epilogue + * - Thread 0 issues the MMAs (tid==0 calls umma_ss_f16) + * - All 128 threads stage Q/K/V from GMEM into SMEM + * - Full 6-warp pipeline with TMA loads is the next milestone * * ================================================================== - * SMEM LAYOUT (for hd=128, s_k per tile=128) + * SMEM LAYOUT (one K-tile at a time, minimal SMEM) * ================================================================== - * sQ: (128, HD) BF16 = 128 * 128 * 2 = 32 KB (row-major, MN-major for UMMA) - * sK: (128, HD) BF16 = 32 KB (row-major in SMEM, used as K^T via K-major UMMA desc) - * sV: (HD, 128) BF16 = 32 KB (col-major in SMEM, K-major for UMMA) - * Total: 96 KB, well within 232 KB SMEM budget + * sQ: (128, 16) BF16 = 2048 BF16 = 4 KB (reused across QK K-tiles) + * sK: (128, 16) BF16 = 4 KB (reused across QK K-tiles) + * sPk: (128, 16) BF16 = 4 KB (reused across PV calls) + * sV: (16, 16) BF16 = 256 BF16 = 512 bytes (one N-sub-tile) + * s_p_vals: 128 floats = 512 bytes (softmax output, row 0 only) + * Total: ~13 KB for all HD values * * ================================================================== * TMEM LAYOUT * 
================================================================== - * TMEM accumulators for S (QK result) and O (PV result): - * - S: 128 rows × 128 cols = 1 TMEM allocation of 128 columns - * - O: 128 rows × HD cols = 1 TMEM allocation of HD columns (or ceil(HD/128)*128) - * After each QK GEMM, softmax is applied by reading S from TMEM into - * registers, computing max/exp/sum, then storing P to TMEM for PV. + * TMEM_N = max(128, HD) columns, power of 2, min 32. + * - QK output: columns 0..127 (S matrix, 128×128) + * - PV output: columns 0..HD-1 (O matrix, 128×HD) + * - PV uses TMEM offset n*16 for N-sub-tile n */ + #pragma once #include "fmha_common.cuh" @@ -53,307 +60,174 @@ namespace dsv4::kernels::attention { template -__global__ void __launch_bounds__(NTHREADS) -fmha_decode_tc( - const bf16_t* __restrict__ q, const bf16_t* __restrict__ k, - const bf16_t* __restrict__ v, bf16_t* __restrict__ o, - int bstride_q, int bstride_kv, int bstride_o, - int s_k, int n_comp, int swa_len, float scale, - const float* __restrict__ attn_sink, float* __restrict__ lse_out -) { - const int head = blockIdx.y, batch = blockIdx.z, tid = threadIdx.x; - const int wid = tid / WARP, lane = tid % WARP; +class FmhaSm100Kernel { + static constexpr int NKT_QK = HD / MMA_K_BF16; + static constexpr int NKT_PV = SK_TILE / MMA_K_BF16; // 8 + static constexpr int N_NSUB = HD / 16; // Number of N=16 sub-tiles + static constexpr int TILE_SZ = 128 * MMA_K_BF16; // 2048 BF16 + static constexpr int V_SUB_SZ = 256; // (16,16) canonical BF16 + static constexpr int TMEM_N = (HD <= 128) ? 128 : ((HD <= 256) ? 256 : 512); - const bf16_t* qh = q + batch*bstride_q + head*HD; - const bf16_t* kb = k + batch*bstride_kv; - const bf16_t* vb = v + batch*bstride_kv; - bf16_t* oh = o + batch*bstride_o + head*HD; +public: + /** + * Launch the FMHA kernel for T=1 decode. 
+ * + * @param q Query tensor [HD] (BF16) + * @param k Key tensor [SK, HD] (BF16, row-major) + * @param v Value tensor [HD, SK] (BF16, row-major) + * @param o Output tensor [HD] (BF16) + * @param s_k Sequence length (must be ≤ SK_TILE for single-tile) + * @param scale Attention scale (1/sqrt(HD)) + * @param stream CUDA stream + */ + static void launch( + const bf16_t* q, const bf16_t* k, const bf16_t* v, + bf16_t* o, int s_k, float scale, cudaStream_t stream = 0 + ) { + // SMEM: tmem(4+12) + sQ(8KB) + sK(8KB) + sPk(8KB) + sV(512B) + s_p_vals(512B) + align + int smem = (4 + 16 + TILE_SZ * 2 + TILE_SZ * 2 + TILE_SZ * 2 + + V_SUB_SZ * 2 + SK_TILE * 4 + 256 + 127) & ~127; - // ================================================================ - // SMEM allocation - // ================================================================ - // sQ: (128, HD) BF16 — row-major, MN-major UMMA desc - // sK: (SK_TILE, HD) BF16 — row-major in SMEM, K-major UMMA desc - // sV: (HD, SK_TILE) BF16 — col-major in SMEM, K-major UMMA desc - // sTmemBase: 4 bytes for TMEM alloc - // sRowMax, sRowSum: per-row softmax state (128 floats each) - extern __shared__ char sbuf[]; - uint32_t* sTmemBase = (uint32_t*)sbuf; - bf16_t* sQ = (bf16_t*)(sbuf + 4); - bf16_t* sK = (bf16_t*)(sbuf + 4 + 128 * HD * sizeof(bf16_t)); - bf16_t* sV = (bf16_t*)(sbuf + 4 + 128 * HD * sizeof(bf16_t) + SK_TILE * HD * sizeof(bf16_t)); - float* sRowMax = (float*)(sbuf + 4 + 128 * HD * sizeof(bf16_t) + SK_TILE * HD * sizeof(bf16_t) * 2); - float* sRowSum = sRowMax + 128; - - // ================================================================ - // TMEM allocation - // ================================================================ - // We need TMEM for O accumulator (128 rows × ceil(HD/128) columns) - constexpr int TMEM_O_COLS = (HD + 127) / 128 * 128; // round up to 128 - // Also need TMEM for S/P (128 rows × 128 cols) — reuse same allocation - // Strategy: alloc once, use for S during QK, then for P, then for O - // 
Actually, S and O need separate allocations since PV reads P and writes O. - // Use 2 TMEM allocations: one for S/P, one for O. - // But TMEM is limited. For hd=128: S needs 128 cols, O needs 128 cols = 256 total. - // Let's use 2 separate allocs. - constexpr int TMEM_S_COLS = 128; - constexpr int TMEM_O_COLS_ALLOC = TMEM_O_COLS; - constexpr int TMEM_TOTAL = TMEM_S_COLS + TMEM_O_COLS_ALLOC; - - // Alloc TMEM — warp-collective - if (wid == 0) { - uint32_t smem_ptr = __cvta_generic_to_shared(sTmemBase); - tmem_alloc(smem_ptr, TMEM_TOTAL); - } - __syncthreads(); - uint32_t tmem_base = *sTmemBase; - uint32_t tmem_s = tmem_base; // S accumulator starts at base - uint32_t tmem_o = tmem_base + TMEM_S_COLS; // O accumulator after S - - // ================================================================ - // Load Q to SMEM — all threads participate - // Q is (1, HD) for T=1 decode, padded to (128, HD) with zeros - // ================================================================ - // Zero all 128 rows first - for (int i = tid; i < 128 * HD; i += NTHREADS) { - sQ[i] = 0; - } - // Load the actual Q row (row 0) - for (int d = tid; d < HD; d += NTHREADS) { - sQ[d] = qh[d]; // Row 0, column d - } - - // Initialize softmax state - for (int i = tid; i < 128; i += NTHREADS) { - sRowMax[i] = -INFINITY; - sRowSum[i] = 0.0f; - } - __syncthreads(); - - // ================================================================ - // UMMA descriptors for Q (fixed across all KV tiles) - // ================================================================ - uint32_t sQ_smem = __cvta_generic_to_shared(sQ); - uint64_t desc_q = make_umma_desc_bf16( - sQ_smem, 128, HD, HD, UmmaMajor::MN); - - // ================================================================ - // KV tile loop - // ================================================================ - int n_tiles = (s_k + SK_TILE - 1) / SK_TILE; - - // Zero TMEM O accumulator - if (wid == 0) { - for (int col = 0; col < TMEM_O_COLS; col++) { - 
tmem_store(tmem_o + col, 0, 0, 0, 0); + if (smem > 48 * 1024) { + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); } - tmem_fence_store(); + + kernel<<<1, 128, smem, stream>>>(q, k, v, o, s_k, scale); } - __syncthreads(); - for (int kt = 0; kt < n_tiles; kt++) { - int kv_start = kt * SK_TILE; - int kv_len = min(SK_TILE, s_k - kv_start); +private: + __global__ void __launch_bounds__(128) + static kernel( + const bf16_t* __restrict__ q, + const bf16_t* __restrict__ k, + const bf16_t* __restrict__ v, + bf16_t* __restrict__ o, + int s_k, float scale + ) { + const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32; - // ============================================================ - // Load K and V for this tile - // ============================================================ - // K: (kv_len, HD) BF16 — load from global, pad to (SK_TILE, HD) - for (int i = tid; i < SK_TILE * HD; i += NTHREADS) { - int r = i / HD, c = i % HD; - if (r < kv_len) { - sK[i] = kb[(kv_start + r) * HD + c]; - } else { - sK[i] = 0; + extern __shared__ char sbuf[]; + uint32_t* sTmemBase = (uint32_t*)sbuf; + bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15); + bf16_t* sK0 = sQ0 + TILE_SZ; + bf16_t* sPk = (bf16_t*)(((uintptr_t)(sK0 + TILE_SZ) + 127) & ~(uintptr_t)127); + bf16_t* sV = (bf16_t*)(((uintptr_t)(sPk + TILE_SZ) + 127) & ~(uintptr_t)127); + float* s_p_vals = (float*)(sV + V_SUB_SZ); + + // TMEM alloc + if (wid == 1) tmem_alloc(__cvta_generic_to_shared(sTmemBase), TMEM_N); + __syncthreads(); + uint32_t tb = *sTmemBase; + + // ===== QK GEMM (one K-tile at a time) ===== + { + uint32_t idesc = make_idesc(128, 128); + for (int kt = 0; kt < NKT_QK; kt++) { + // Load Q K-tile + for (int i = tid; i < TILE_SZ; i += 128) sQ0[i] = 0; + for (int d = tid; d < MMA_K_BF16; d += 128) { + int ck = d / 8, lc = d % 8; + sQ0[ck * 16 * 64 + lc] = q[kt * MMA_K_BF16 + d]; + } + // Load K K-tile + for (int i = tid; i < TILE_SZ; i += 128) sK0[i] = 0; + for 
(int r = 0; r < s_k; r++) { + for (int d = tid; d < MMA_K_BF16; d += 128) { + int ck = d / 8, lc = d % 8; + int tmn = r / 8, lr = r % 8; + sK0[ck * 16 * 64 + tmn * 64 + lr * 8 + lc] = k[r * HD + kt * MMA_K_BF16 + d]; + } + } + __syncthreads(); + + uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), 128); + uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK0), 128); + if (tid == 0) umma_ss_f16(tb, dq, dk, idesc, kt > 0); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + __syncthreads(); } } - // V: (HD, kv_len) BF16 — load from global, pad to (HD, SK_TILE) - // V layout in GMEM: vb[d * s_k + kv_start + c] for row d, col c - // SMEM layout: (HD, SK_TILE) column-major — sV[d * SK_TILE + c] - for (int i = tid; i < HD * SK_TILE; i += NTHREADS) { - int d = i / SK_TILE, c = i % SK_TILE; - if (c < kv_len) { - sV[i] = vb[d * s_k + kv_start + c]; - } else { - sV[i] = 0; - } - } - __syncthreads(); - - // ============================================================ - // QK GEMM: S = Q @ K^T (SS, both SMEM → TMEM) - // ============================================================ - // A = Q: (128, HD), MN-major - // B = K^T: (HD, 128), K-major — K is (128, HD) in SMEM, - // transposed via the UMMA descriptor (K-major means A is (K, M)) - uint32_t sK_smem = __cvta_generic_to_shared(sK); - uint64_t desc_k = make_umma_desc_bf16( - sK_smem, 128, HD, HD, UmmaMajor::K); - - // Only one lane per warp calls MMA - if (lane == 0) { - umma_ss_f16(tmem_s, desc_q, desc_k, /*accumulate=*/false); - } - __syncwarp(); // Wait for MMA to complete - // TMEM fence after MMA - if (wid == 0 && lane == 0) { - tmem_fence_store(); - } - __syncthreads(); - - // ============================================================ - // Softmax: read S from TMEM, compute max/exp/sum, write P to TMEM - // ============================================================ - // This is the D1.5 softmax with O rescale. - // For each row m (0..127): - // 1. 
Read S[m, 0..127] from TMEM - // 2. Compute local_max = max(S[m, :]) * scale - // 3. If local_max > row_max[m]: - // - Rescale O: O[m,:] *= exp(row_max[m] - local_max) - // - row_sum[m] *= exp(row_max[m] - local_max) - // - row_max[m] = local_max - // 4. P[m, j] = exp((S[m, j] * scale) - row_max[m]) - // 5. row_sum[m] += sum(P[m, :]) - // 6. Store P to TMEM (overwrite S since we're done with it) - // - // With 32 lanes, each lane handles 4 rows (128/32=4). - // Each row has 128 values spread across 128 TMEM columns. - // Lane i reads its 4 FP32 from each column (positions i*4+0..3). + // ===== Softmax (row 0 only for T=1 decode) ===== if (wid == 0) { - // Each lane handles 4 rows: rows [lane*4 .. lane*4+3] - int row0 = lane * 4; - float local_max[4] = {-INFINITY, -INFINITY, -INFINITY, -INFINITY}; - - // Step 1: Find per-row max across all S columns - for (int col = 0; col < 128; col++) { - uint32_t u0, u1, u2, u3; - tmem_load(tmem_s + col, u0, u1, u2, u3); - float s0 = u32_to_f32(u0) * scale; - float s1 = u32_to_f32(u1) * scale; - float s2 = u32_to_f32(u2) * scale; - float s3 = u32_to_f32(u3) * scale; - local_max[0] = fmaxf(local_max[0], s0); - local_max[1] = fmaxf(local_max[1], s1); - local_max[2] = fmaxf(local_max[2], s2); - local_max[3] = fmaxf(local_max[3], s3); + float s_vals[SK_TILE], row_max = -INFINITY; + for (int n = 0; n < SK_TILE / 8; n++) { + float tmp[8]; + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" + : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), + "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) + : "r"(tb + n*8)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + if (lane == 0) for (int c=0;c<8;c++) { + s_vals[n*8+c] = tmp[c] * scale; + row_max = fmaxf(row_max, tmp[c] * scale); + } } - - // Share max across warps (we only use warp 0 for now) - // Store per-row max to SMEM - if (lane < 32) { - sRowMax[row0+0] = local_max[0]; - sRowMax[row0+1] = local_max[1]; - sRowMax[row0+2] = local_max[2]; - 
sRowMax[row0+3] = local_max[3]; + row_max = wmax(row_max); + float row_sum = 0.0f; + if (lane == 0) for (int j=0;j 0 ? sRowMax[0] : -INFINITY; // hack: use sRowSum as sentinel - // Actually, track row_max and row_sum properly + // ===== PV GEMM: N=16 sub-tiles ===== + { + uint32_t idesc_pv16 = make_idesc(128, 16); + uint64_t dp = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sPk), 128); + + for (int n = 0; n < N_NSUB; n++) { + int d_base = n * 16; + for (int kt = 0; kt < NKT_PV; kt++) { + // Fill sPk + for (int i = tid; i < TILE_SZ; i += 128) sPk[i] = 0; + if (tid < 16) { + int c = tid; + int ck = c / 8, lc = c % 8; + sPk[ck * 16 * 64 + 0 * 64 + 0 * 8 + lc] = f32_to_bf16(s_p_vals[kt * MMA_K_BF16 + c]); + } + // Load V sub-tile + for (int i = tid; i < V_SUB_SZ; i += 128) sV[i] = 0; + for (int dd = tid; dd < 16; dd += 128) { + for (int lr = 0; lr < MMA_K_BF16; lr++) { + int r = kt * MMA_K_BF16 + lr; + int g_mn = dd / 8, g_k = lr / 8; + int llr = dd % 8, lc = lr % 8; + sV[g_k * 2 * 64 + g_mn * 64 + llr * 8 + lc] = v[(d_base + dd) * SK_TILE + r]; + } + } + __syncthreads(); + + uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16); + if (tid == 0) umma_ss_f16(tb + n * 16, dp, dv, idesc_pv16, kt > 0); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + __syncthreads(); + } + } } - // This is getting complex. Let me simplify for the first pass: - // For T=1 decode, only row 0 matters. Skip the full 128-row softmax - // and just do row 0 with the scalar approach, reading S from TMEM. - - // ============================================================ - // SIMPLIFIED: Row-0 only softmax + PV (T=1 decode) - // ============================================================ - // Read S[0, 0..kv_len-1] from TMEM, compute softmax, do P@V in scalar, - // accumulate into TMEM O. - // This proves the QK GEMM works while keeping softmax simple. 
- if (tid == 0) { - float row_max_0 = -INFINITY; - float s_vals[SK_TILE]; - - // Read row 0 from TMEM: row 0 is in lane 0's first register - // across all 128 columns - for (int col = 0; col < 128; col++) { - uint32_t u0, u1, u2, u3; - tmem_load(tmem_s + col, u0, u1, u2, u3); - s_vals[col] = u32_to_f32(u0) * scale; // lane 0's first value = row 0 - } - - // Find max - for (int j = 0; j < kv_len; j++) { - row_max_0 = fmaxf(row_max_0, s_vals[j]); - } - - // Apply SWA mask - // (skipping for now — add later) - - // Softmax + O rescale - float new_max = fmaxf(sRowMax[0], row_max_0); - if (new_max > sRowMax[0]) { - float rescale = (sRowMax[0] > -INFINITY) ? expf(sRowMax[0] - new_max) : 0.0f; - // Rescale O in TMEM - for (int col = 0; col < TMEM_O_COLS; col++) { - uint32_t u0, u1, u2, u3; - tmem_load(tmem_o + col, u0, u1, u2, u3); - float r0 = u32_to_f32(u0) * rescale; - float r1 = u32_to_f32(u1) * rescale; - float r2 = u32_to_f32(u2) * rescale; - float r3 = u32_to_f32(u3) * rescale; - tmem_store(tmem_o + col, f32_to_u32(r0), f32_to_u32(r1), - f32_to_u32(r2), f32_to_u32(r3)); - } - sRowSum[0] *= rescale; - sRowMax[0] = new_max; - } - - // Compute P and accumulate P@V - for (int j = 0; j < kv_len; j++) { - float p_val = expf(s_vals[j] - sRowMax[0]); - sRowSum[0] += p_val; - for (int d = 0; d < HD; d++) { - // Read O[d] from TMEM - int col = d / 4; // which TMEM column - int slot = d % 4; // which register slot (only slot 0 = row 0 for lane 0) - // Actually, with the lane mapping, lane 0's 4 regs in column col - // correspond to positions col*128 + 0..3 (NOT col*4) - // So for row 0, d-th value is in column d/128, lane (d%128)/4, slot d%4 - // This doesn't work with tid==0 approach... 
- // Let me just use sO in SMEM for accumulation instead - break; // This approach doesn't work — need to rethink - } + // ===== Epilogue: TMEM → regs → BF16 → GMEM ===== + if (wid == 0) { + float o_vals[HD]; + for (int n = 0; n < HD / 8; n++) { + float tmp[8]; + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" + : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), + "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) + : "r"(tb + n*8)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c]; } + if (lane == 0) for (int d=0;d