diff --git a/dsv4/kernels/attention/fmha_6warp_tma_multirow.cuh b/dsv4/kernels/attention/fmha_6warp_tma_multirow.cuh new file mode 100644 index 00000000..2ad3a8b2 --- /dev/null +++ b/dsv4/kernels/attention/fmha_6warp_tma_multirow.cuh @@ -0,0 +1,306 @@ +/** + * DSV4 FMHA — 6-warp TMA kernel, multi-row softmax (prefill T>1). + * + * Based on fmha_6warp_multirow.cuh with TMA async loads for K. + * Q and V remain direct GMEM loads for now. + * + * ================================================================== + * DESIGN + * ================================================================== + * + * 6-warp CTA: warps 0-3 = softmax, warp 4 = MMA, warp 5 = TMA load. + * Grid: (1, n_h, batch) — each CTA processes one head of one batch item. + * + * TMA pipeline: + * - K: TMA async load via cp.async.bulk.tensor.2d with mbarrier + * - Q: direct GMEM load (multi-row, but small enough for warp-stride) + * - V: direct GMEM load (16×16 sub-tiles) + * - sTmaBuf: staging area for TMA→canonical conversion + * + * Flow: + * 1. QK GEMM: Q direct + K TMA → S in TMEM + * 2. Softmax: 2-pass (row_max, exp+sum+P), P in registers + * 3. PV GEMM: P→sPk + V direct → O in TMEM + * 4. 
Epilogue: O from TMEM → normalize → BF16 → GMEM + LSE + */ + +#pragma once + +#include "fmha_common.cuh" +#include "fmha_umma_desc.cuh" +#include "fmha_tma.cuh" + +namespace dsv4::kernels::attention { + +struct FmhaTmaMultiRowParams { + const bf16_t* __restrict__ q; + CUtensorMap* __restrict__ tma_k; // K: (s_k, HD) with tile (128, 16) + const bf16_t* __restrict__ v; // V: direct GMEM (HD, s_k) + bf16_t* __restrict__ o; + float* __restrict__ lse; // optional: the kernel skips the LSE store when this is null + int s_k, T; // key length / query rows; the kernel's SMEM tiles hold at most 128 rows of each (not checked) + float scale; + int head_dim; // not read by the kernel below, which uses the compile-time HD + int q_head_stride, q_batch_stride; + int k_head_stride, k_batch_stride; // not read by the kernel below; K arrives through tma_k + int v_head_stride, v_batch_stride; + int o_head_stride, o_batch_stride; + int lse_head_stride, lse_batch_stride; +}; + +template +__global__ void __launch_bounds__(192) +fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) { + static constexpr int NKT_QK = HD / MMA_K_BF16; + static constexpr int NKT_PV = SK_TILE / MMA_K_BF16; + static constexpr int N_NSUB = HD / 16; + static constexpr int TILE_SZ = 128 * MMA_K_BF16; + static constexpr int V_SUB_SZ = 16 * MMA_K_BF16; + static constexpr int TMEM_N = (HD <= 128) ? 
128 : 256; + static constexpr int MAX_ROWS = 128; + static constexpr int CORES_MN = 128 / 8; + static constexpr int NUM_READS = SK_TILE / 8; + static constexpr int TMA_TILE_BYTES = TILE_SZ * sizeof(bf16_t); + + const int head_idx = blockIdx.y; + const int batch_idx = blockIdx.z; + const int tid = threadIdx.x; + const int wid = tid / 32; + const int lane = tid % 32; + const bool is_softmax_warp = (wid < 4); + const bool is_mma_warp = (wid == 4); + const bool is_load_warp = (wid == 5); + const int T = params.T; + const int s_k = params.s_k; + const float scale = params.scale; + + const bf16_t* __restrict__ q_head = params.q + head_idx * params.q_head_stride + batch_idx * params.q_batch_stride; + const bf16_t* __restrict__ v_head = params.v + head_idx * params.v_head_stride + batch_idx * params.v_batch_stride; + bf16_t* __restrict__ o_head = params.o + head_idx * params.o_head_stride + batch_idx * params.o_batch_stride; + float* __restrict__ lse_head = params.lse ? params.lse + head_idx * params.lse_head_stride + batch_idx * params.lse_batch_stride : nullptr; + + // ================================================================ + // SMEM allocation — 128-byte aligned for TMA + // ================================================================ + extern __shared__ __align__(128) char sbuf[]; + size_t off = 0; + uint32_t* sTmemBase = (uint32_t*)(sbuf + off); off += 4; + off = (off + 127) & ~(size_t)127; + uint64_t* sMbar = (uint64_t*)(sbuf + off); off += 16; + off = (off + 127) & ~(size_t)127; + bf16_t* sTmaBuf = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t); + off = (off + 127) & ~(size_t)127; + bf16_t* sQ0 = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t); + off = (off + 127) & ~(size_t)127; + bf16_t* sK0 = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t); + off = (off + 127) & ~(size_t)127; + bf16_t* sPk = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t); + off = (off + 127) & ~(size_t)127; + bf16_t* sV = (bf16_t*)(sbuf + off); off += 
V_SUB_SZ * sizeof(bf16_t); + float* sRowMax = (float*)(sbuf + off); off += MAX_ROWS * sizeof(float); + float* sRowSum = (float*)(sbuf + off); off += MAX_ROWS * sizeof(float); + + // TMEM alloc + mbarrier init + if (is_mma_warp) tmem_alloc(__cvta_generic_to_shared(sTmemBase), TMEM_N); + if (tid == 0) { + tma_mbarrier_init((uint32_t)__cvta_generic_to_shared(sMbar), 1); + asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory"); + } + __syncthreads(); + uint32_t tb = *sTmemBase; + const uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar); + int phase = 0; // mbarrier parity passed to the wait; flipped after every wait below + + // Row assignment + const bool my_warp_active = (T <= 32) ? (wid == 0) : is_softmax_warp; + const int my_row = my_warp_active ? (wid * 32 + lane) : 0; + const bool my_row_active = my_warp_active && (my_row < T); + + // ================================================================ + // QK GEMM → S in TMEM + // ================================================================ + for (int kt = 0; kt < NKT_QK; kt++) { + // Load Q: direct from GMEM, all T rows + if (is_load_warp) { + for (int i = lane; i < TILE_SZ; i += 32) sQ0[i] = 0; + for (int r = 0; r < T; r++) { + for (int d = lane; d < MMA_K_BF16; d += 32) { + int full_d = kt * MMA_K_BF16 + d; + if (full_d < HD) { + int ck = d/8, lc = d%8, cm = r/8, lr = r%8; + sQ0[ck*CORES_MN*64 + cm*64 + lr*8 + lc] = q_head[r * HD + full_d]; + } + } + } + } + + // Load K: TMA async + if (is_load_warp && lane == 0) { + tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)params.tma_k, + mbar_addr, kt * MMA_K_BF16, 0); + tma_mbarrier_arrive_expect_tx(mbar_addr, TMA_TILE_BYTES); + } + tma_mbarrier_wait(mbar_addr, phase); phase ^= 1; + __syncthreads(); + + // Convert TMA row-major → canonical + for (int i = tid; i < TILE_SZ; i += 192) sK0[i] = 0; + for (int i = tid; i < s_k * MMA_K_BF16; i += 192) { + int r = i / MMA_K_BF16, c = i % MMA_K_BF16; + int ck = c/8, lc = c%8, tmn = r/8, lr = r%8; + sK0[ck*CORES_MN*64 + tmn*64 + lr*8 + lc] = sTmaBuf[i]; 
+ } + __syncthreads(); + + // MMA + if (is_mma_warp) { + uint32_t idesc = make_idesc(128, 128); + uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), 128); + uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK0), 128); + if (tid == 128) umma_ss_f16(tb, dq, dk, idesc, kt > 0); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + } + __syncthreads(); + } + + asm volatile("fence.sc.gpu;" ::: "memory"); + __syncthreads(); + + // ================================================================ + // SOFTMAX — 2-pass, P in registers + // ================================================================ + // Pass 1: row_max + float my_row_max = -INFINITY; + if (my_warp_active) { + for (int n = 0; n < NUM_READS; n++) { + float tmp[8]; + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" + : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), + "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) + : "r"(tb + n * 8)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + if (my_row_active) { + for (int c = 0; c < 8; c++) { + int col = n * 8 + c; + if (col < s_k) my_row_max = fmaxf(my_row_max, tmp[c] * scale); + } + } + } + } + if (my_row_active) sRowMax[my_row] = my_row_max; + __syncthreads(); + + // Pass 2: exp + sum + P + float my_p_vals[SK_TILE]; + float my_row_sum = 0.0f; + if (my_warp_active) { + float rm = my_row_active ? 
sRowMax[my_row] : 0.0f; + for (int n = 0; n < NUM_READS; n++) { + float tmp[8]; + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" + : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), + "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) + : "r"(tb + n * 8)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + if (my_row_active) { + for (int c = 0; c < 8; c++) { + int col = n * 8 + c; + if (col < s_k) { + float p = expf(tmp[c] * scale - rm); + my_p_vals[col] = p; + my_row_sum += p; + } + } + } + } + } + if (my_row_active) sRowSum[my_row] = my_row_sum; + __syncthreads(); + + // ================================================================ + // PV GEMM — P→sPk + V direct → O in TMEM + // ================================================================ + for (int n_sub = 0; n_sub < N_NSUB; n_sub++) { + int d_base = n_sub * 16; + for (int pv_kt = 0; pv_kt < NKT_PV; pv_kt++) { + const int col_start = pv_kt * MMA_K_BF16; + + // Zero sPk + if (is_load_warp) { + for (int i = lane; i < TILE_SZ; i += 32) sPk[i] = 0; + } + __syncthreads(); + + // Softmax warps: write P to sPk (all active rows); pass 2 never wrote my_p_vals for columns >= s_k, so those store 0 + if (my_row_active) { + for (int c = 0; c < MMA_K_BF16; c++) { + int gc = col_start + c; + int ck = c/8, lc = c%8; + int core_mn = my_row/8, local_r = my_row%8; + sPk[ck*CORES_MN*64 + core_mn*64 + local_r*8 + lc] = f32_to_bf16(gc < s_k ? my_p_vals[gc] : 0.0f); + } + } + __syncthreads(); + + // Load V sub-tile: direct from GMEM + if (is_load_warp) { + for (int i = lane; i < V_SUB_SZ; i += 32) sV[i] = 0; + for (int dd = lane; dd < 16; dd += 32) { + for (int lr = 0; lr < MMA_K_BF16; lr++) { + int r = col_start + lr; + if (r < s_k && (d_base + dd) < HD) { + int g_mn = dd/8, g_k = lr/8, llr = dd%8, lc = lr%8; + sV[g_k*2*64 + g_mn*64 + llr*8 + lc] = v_head[(d_base+dd)*s_k + r]; + } + } + } + } + __syncthreads(); + + // MMA + if (is_mma_warp) { + uint32_t idesc_pv = make_idesc(128, 16); + uint64_t dp = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sPk), 128); + 
uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16); + if (tid == 128) umma_ss_f16(tb + n_sub*16, dp, dv, idesc_pv, pv_kt > 0); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + } + __syncthreads(); + } + } + + asm volatile("fence.sc.gpu;" ::: "memory"); + __syncthreads(); + + // ================================================================ + // EPILOGUE — O from TMEM → normalize → GMEM + LSE + // TMEM loads are warp-collective: MUST be outside my_row_active guard + // ================================================================ + if (my_warp_active) { + float rm = my_row_active ? sRowMax[my_row] : 0.0f; + float rs = my_row_active ? sRowSum[my_row] : 0.0f; + float inv_rs = my_row_active ? (1.0f / rs) : 0.0f; + + for (int n = 0; n < N_NSUB * 2; n++) { + float tmp[8]; + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" + : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), + "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) + : "r"(tb + n * 8)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + + if (my_row_active) { + for (int c = 0; c < 8; c++) { + int d = n * 8 + c; + if (d < HD) o_head[my_row * HD + d] = f32_to_bf16(tmp[c] * inv_rs); + } + } + } + if (my_row_active && lse_head) lse_head[my_row] = logf(rs) + rm; + } + __syncthreads(); + + if (is_mma_warp) tmem_dealloc(tb, TMEM_N); +} + +} // namespace diff --git a/tests/unit/test_fmha_6warp_tma_multirow.cu b/tests/unit/test_fmha_6warp_tma_multirow.cu new file mode 100644 index 00000000..8fe6fe0f --- /dev/null +++ b/tests/unit/test_fmha_6warp_tma_multirow.cu @@ -0,0 +1,163 @@ +/** + * Test 6-warp TMA FMHA multi-row kernel (T>1 prefill). 
+ */ + +#include +#include +#include +#include +#include +#include + +#ifndef HD_VAL +#define HD_VAL 64 +#endif + +#include "dsv4/kernels/attention/fmha_common.cuh" +#include "dsv4/kernels/attention/fmha_umma_desc.cuh" +#include "dsv4/kernels/attention/fmha_tma.cuh" + +using namespace dsv4::kernels::attention; + +static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); } +static float bf16_to_f32_host(bf16_t h) { uint32_t u=(uint32_t)h<<16; float f; memcpy(&f,&u,4); return f; } + +constexpr int HD = HD_VAL; +constexpr int SK = 128; +constexpr int MAX_T = 128; +constexpr int MY_MMA_K = 16; +constexpr int TILE_SZ = 128 * MY_MMA_K; + +#include "dsv4/kernels/attention/fmha_6warp_tma_multirow.cuh" + +static size_t compute_smem() { // dynamic SMEM bytes; mirrors the kernel's carve-up order and 128-byte rounding + size_t off = 0; + off += 4; off = (off+127)&~(size_t)127; + off += 16; off = (off+127)&~(size_t)127; + off += TILE_SZ * 2; off = (off+127)&~(size_t)127; // sTmaBuf + off += TILE_SZ * 2; off = (off+127)&~(size_t)127; // sQ0 + off += TILE_SZ * 2; off = (off+127)&~(size_t)127; // sK0 + off += TILE_SZ * 2; off = (off+127)&~(size_t)127; // sPk + off += 16 * MY_MMA_K * 2; // sV + off += 128 * 4; // sRowMax + off += 128 * 4; // sRowSum + return off; +} + +static void reference_attention( + const bf16_t* q, const bf16_t* k, const bf16_t* v, + float* o_ref, float* lse_ref, + int hd, int T, int s_k, float scale +) { + for (int t = 0; t < T; t++) { + float s[512]; // per-row score scratch: callers must keep s_k <= 512 + for (int j = 0; j < s_k; j++) { + float dot = 0.0f; + for (int d = 0; d < hd; d++) dot += bf16_to_f32_host(q[t*hd+d]) * bf16_to_f32_host(k[j*hd+d]); + s[j] = dot * scale; + } + float mx = -INFINITY; + for (int j = 0; j < s_k; j++) mx = fmaxf(mx, s[j]); + float sm = 0.0f; + for (int j = 0; j < s_k; j++) { s[j] = expf(s[j] - mx); sm += s[j]; } + for (int j = 0; j < s_k; j++) s[j] /= sm; + for (int d = 0; d < hd; d++) { + float ov = 0.0f; + for (int j = 0; j < s_k; j++) ov += s[j] * bf16_to_f32_host(v[d*s_k+j]); + o_ref[t * hd + d] = ov; + } + if 
(lse_ref) lse_ref[t] = logf(sm) + mx; + } +} + +int main() { + printf("=== 6-warp TMA FMHA multi-row HD=%d ===\n", HD); + const float SCALE = 1.0f / sqrtf((float)HD); + + int total_fail = 0; + + // Test T=1,4,32,128 + for (int T : {1, 4, 32, 128}) { + printf("\n--- T=%d ---\n", T); + + bf16_t* h_q = (bf16_t*)calloc(MAX_T * HD, sizeof(bf16_t)); + bf16_t* h_k = (bf16_t*)calloc(SK * HD, sizeof(bf16_t)); + bf16_t* h_v = (bf16_t*)calloc(HD * SK, sizeof(bf16_t)); + bf16_t* h_o = (bf16_t*)calloc(MAX_T * HD, sizeof(bf16_t)); + float* h_lse = (float*)calloc(MAX_T, sizeof(float)); + + srand(42); + for (int i=0;i 48*1024) + cudaFuncSetAttribute(fmha_6warp_tma_multirow_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, (int)smem); + + fmha_6warp_tma_multirow_kernel<<<1, 192, smem>>>(params); + + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + printf(" CUDA ERROR: %s\n", cudaGetErrorString(err)); + total_fail++; continue; + } + + // Copy results back to the host; d_o and d_lse are device buffers, so the copy kind is DeviceToHost + cudaMemcpy(h_o, d_o, T*HD*sizeof(bf16_t), cudaMemcpyDeviceToHost); + cudaMemcpy(h_lse, d_lse, T*sizeof(float), cudaMemcpyDeviceToHost); + + // Reference (lse_ref is nullptr, so no reference LSE is produced for h_lse to be compared against) + float* o_ref = (float*)calloc(T*HD, sizeof(float)); + reference_attention(h_q, h_k, h_v, o_ref, nullptr, HD, T, SK, SCALE); + + // Check + float cs=0,na=0,nb=0; int bad=0; + for (int t=0;t 1e-4f) { cs+=a*b; na+=a*a; nb+=b*b; } + float rel = fabsf(b)>1e-4f ? fabsf(a-b)/fabsf(b) : fabsf(a-b); + if (rel > 0.01f) bad++; + } + } + cs /= (sqrtf(na)*sqrtf(nb)+1e-10f); + printf(" T=%d: cosine=%.8f bad=%d %s\n", T, cs, bad, bad==0&&cs>0.999f?"PASS":"FAIL"); + if (cs < 0.999f || bad > 0) total_fail++; + + cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o); cudaFree(d_lse); cudaFree(d_tma_k); + free(h_q); free(h_k); free(h_v); free(h_o); free(h_lse); free(o_ref); + } + + printf("\nOverall: %s\n", total_fail==0?"ALL PASSED":"SOME FAILED"); + return total_fail == 0 ? 0 : 1; +}