diff --git a/dsv4/kernels/attention/fmha_6warp.cuh b/dsv4/kernels/attention/fmha_6warp.cuh new file mode 100644 index 00000000..bb9c86c8 --- /dev/null +++ b/dsv4/kernels/attention/fmha_6warp.cuh @@ -0,0 +1,235 @@ +/** + * DSV4 FMHA — 6-warp specialized kernel for Blackwell SM100. + * + * ================================================================== + * WARP SPECIALIZATION + * ================================================================== + * Warp 0-3 (tid 0-127): Softmax + correction + epilogue + * - Read S from TMEM, compute softmax, write P to SMEM + * - After PV: read O from TMEM, normalize, write to GMEM + * - For T=1 decode: only warp 0 processes row 0 + * + * Warp 4 (tid 128-159): MMA (QK + PV) + * - Call tcgen05.mma for QK and PV + * - TMEM alloc/dealloc + * - Only 1 thread calls MMA, but TMEM ops are warp-collective + * + * Warp 5 (tid 160-191): Data staging (Q/K/V loads) + * - Load Q, K, V from GMEM to SMEM in canonical layout + * - Future: TMA loads with mbarrier + * - Fill sPk from s_p_vals + * + * ================================================================== + * SYNCHRONIZATION + * ================================================================== + * CTA-wide __syncthreads() barriers between phases: + * 1. After Q/K/V loads → QK MMA + * 2. After QK MMA → softmax + * 3. After softmax + P fill + V load → PV MMA + * 4. After PV MMA → epilogue + * + * Future: mbarrier-based producer-consumer sync between warp 5 (producer) + * and warp 4 (consumer) for pipeline overlap. + * + * ================================================================== + * SMEM LAYOUT (shared across all warps) + * ================================================================== + * sQ: (128, 16) canonical = 8 KB (1 K-tile, reused) + * sK: (128, 16) canonical = 8 KB (1 K-tile, reused) + * sPk: (128, 16) canonical = 8 KB (1 sub-tile, reused) + * sV: (16, 16) canonical = 512 bytes (1 N-sub-tile) + * s_p_vals: 128 floats = 512 bytes (softmax output) + * sRowMax: 1 float (row 0 max, for T=1) + * sRowSum: 1 float (row 0 sum, for T=1) + * sTmemBase: 4 bytes (TMEM allocation) + * Total: ~26 KB + */ + +#pragma once + +#include "fmha_common.cuh" +#include "fmha_umma_desc.cuh" + +namespace dsv4::kernels::attention { + +template +__global__ void __launch_bounds__(192) +fmha_6warp_kernel( + const bf16_t* __restrict__ q, + const bf16_t* __restrict__ k, + const bf16_t* __restrict__ v, + bf16_t* __restrict__ o, + int s_k, float scale +) { + static constexpr int NKT_QK = HD / MMA_K_BF16; + static constexpr int NKT_PV = SK_TILE / MMA_K_BF16; // 8 + static constexpr int N_NSUB = HD / 16; + static constexpr int TILE_SZ = 128 * MMA_K_BF16; // 2048 BF16 + static constexpr int V_SUB_SZ = 256; // (16,16) canonical BF16 + static constexpr int TMEM_N = (HD <= 128) ? 128 : 256; + + const int tid = threadIdx.x; + const int wid = tid / 32; + const int lane = tid % 32; + + // Warp role predicates + const bool is_softmax_warp = (wid < 4); // Warps 0-3 + const bool is_mma_warp = (wid == 4); // Warp 4 + const bool is_load_warp = (wid == 5); // Warp 5 + + // ================================================================ + // SMEM allocation (shared across all warps) + // ================================================================ + extern __shared__ char sbuf[]; + uint32_t* sTmemBase = (uint32_t*)sbuf; + float* sRowMax = (float*)(sbuf + 4); + float* sRowSum = sRowMax + 1; + bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sRowSum + 1) + 15) & ~(uintptr_t)15); + bf16_t* sK0 = sQ0 + TILE_SZ; + bf16_t* sPk = (bf16_t*)(((uintptr_t)(sK0 + TILE_SZ) + 127) & ~(uintptr_t)127); + bf16_t* sV = (bf16_t*)(((uintptr_t)(sPk + TILE_SZ) + 127) & ~(uintptr_t)127); + float* s_p_vals = (float*)(sV + V_SUB_SZ); + + // ================================================================ + // TMEM allocation (warp 4) + // ================================================================ + if (is_mma_warp) { + uint32_t smem_ptr = __cvta_generic_to_shared(sTmemBase); + tmem_alloc(smem_ptr, TMEM_N); + } + __syncthreads(); + uint32_t tb = *sTmemBase; + + // ================================================================ + // QK GEMM loop: for each K-tile, load Q+K, then MMA + // ================================================================ + for (int kt = 0; kt < NKT_QK; kt++) { + // ---- Warp 5: Load Q and K for this K-tile ---- + if (is_load_warp) { + // Load Q K-tile + for (int i = lane; i < TILE_SZ; i += 32) sQ0[i] = 0; + for (int d = lane; d < MMA_K_BF16; d += 32) { + int ck = d / 8, lc = d % 8; + sQ0[ck * 16 * 64 + lc] = q[kt * MMA_K_BF16 + d]; + } + // Load K K-tile + for (int i = lane; i < TILE_SZ; i += 32) sK0[i] = 0; + for (int r = 0; r < s_k; r++) { + for (int d = lane; d < MMA_K_BF16; d += 32) { + int ck = d / 8, lc = d % 8; + int tmn = r / 8, lr = r % 8; + sK0[ck * 16 * 64 + tmn * 64 + lr * 8 + lc] = k[r * HD + kt * MMA_K_BF16 + d]; + } + } + } + __syncthreads(); // Wait for loads + + // ---- Warp 4: QK MMA ---- + if (is_mma_warp) { + uint32_t idesc = make_idesc(128, 128); + uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), 128); + uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK0), 128); + if (tid == 128) umma_ss_f16(tb, dq, dk, idesc, kt > 0); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + } + __syncthreads(); // Wait for MMA + } + + // ================================================================ + // Softmax (warp 0, row 0 only for T=1 decode) + // ================================================================ + if (wid == 0) { + float s_vals[SK_TILE], row_max = -INFINITY; + for (int n = 0; n < SK_TILE / 8; n++) { + float tmp[8]; + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" + : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), + "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) + : "r"(tb + n*8)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + if (lane == 0) for (int c=0;c<8;c++) { + s_vals[n*8+c] = tmp[c] * scale; + row_max = fmaxf(row_max, tmp[c] * scale); + } + } + row_max = wmax(row_max); + if (lane == 0) *sRowMax = row_max; + float row_sum = 0.0f; + if (lane == 0) for (int j=0;j 0); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + } + __syncthreads(); // Wait for MMA + } + } + + // ================================================================ + // Epilogue: TMEM → regs → normalize → BF16 → GMEM (warp 0) + // ================================================================ + if (wid == 0) { + float inv_sum = 1.0f / *sRowSum; + float o_vals[HD]; + for (int n = 0; n < HD / 8; n++) { + float tmp[8]; + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" + : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), + "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) + : "r"(tb + n*8)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c] * inv_sum; + } + if (lane == 0) for (int d=0;d +#include +#include +#include +#include + +#ifndef HD_VAL +#define HD_VAL 64 +#endif + +#include "dsv4/kernels/attention/fmha_common.cuh" +#include "dsv4/kernels/attention/fmha_umma_desc.cuh" + +using namespace dsv4::kernels::attention; + +static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); } +static float bf16_to_f32_host(bf16_t h) { uint32_t u=(uint32_t)h<<16; float f; memcpy(&f,&u,4); return f; } + +constexpr int HD = HD_VAL; +constexpr int SK = 128; +constexpr int NKT_QK = HD / MMA_K_BF16; +constexpr int NKT_PV = SK / MMA_K_BF16; +constexpr int TILE_SZ = 128 * MMA_K_BF16; +constexpr int N_NSUB = HD / 16; +constexpr int V_SUB_SZ = 256; +constexpr int TMEM_N = (HD <= 128) ? 128 : 256; + +// Include the kernel +#include "dsv4/kernels/attention/fmha_6warp.cuh" + +int main() { + printf("=== 6-warp FMHA HD=%d ===\n", HD); + const float SCALE = 1.0f / sqrtf((float)HD); + + bf16_t* h_q = (bf16_t*)malloc(HD*sizeof(bf16_t)); + bf16_t* h_k = (bf16_t*)malloc(SK*HD*sizeof(bf16_t)); + bf16_t* h_v = (bf16_t*)malloc(HD*SK*sizeof(bf16_t)); + bf16_t* h_o = (bf16_t*)calloc(HD, sizeof(bf16_t)); + + srand(42); + for (int d=0;d 48 * 1024) { + cudaFuncSetAttribute(fmha_6warp_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); + } + fmha_6warp_kernel<<<1, 192, smem>>>(d_q, d_k, d_v, d_o, SK, SCALE); + + cudaError_t launch_err = cudaGetLastError(); + if (launch_err != cudaSuccess) { printf("LAUNCH ERROR: %s\n", cudaGetErrorString(launch_err)); return 1; } + + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; } + + cudaMemcpy(h_o, d_o, HD*sizeof(bf16_t), cudaMemcpyDeviceToHost); + + printf("O[0..7] MMA: "); for(int d=0;d1e-4f) { cs+=a*b; na+=a*a; nb+=b*b; } + } + cs /= (sqrtf(na)*sqrtf(nb)+1e-10f); + printf("Filtered cosine: %.8f\n", cs); + printf("Test %s\n", cs > 0.999f ? "PASSED" : "FAILED"); + + cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o); + free(h_q); free(h_k); free(h_v); free(h_o); + return cs > 0.999f ? 0 : 1; +} diff --git a/tests/unit/test_fmha_6warp_hd128.cu b/tests/unit/test_fmha_6warp_hd128.cu new file mode 100644 index 00000000..5481eda5 --- /dev/null +++ b/tests/unit/test_fmha_6warp_hd128.cu @@ -0,0 +1,2 @@ +#define HD_VAL 128 +#include "test_fmha_6warp.cu" diff --git a/tests/unit/test_fmha_6warp_hd16.cu b/tests/unit/test_fmha_6warp_hd16.cu new file mode 100644 index 00000000..affbee88 --- /dev/null +++ b/tests/unit/test_fmha_6warp_hd16.cu @@ -0,0 +1,2 @@ +#define HD_VAL 16 +#include "test_fmha_6warp.cu" diff --git a/tests/unit/test_fmha_6warp_hd256.cu b/tests/unit/test_fmha_6warp_hd256.cu new file mode 100644 index 00000000..812db34d --- /dev/null +++ b/tests/unit/test_fmha_6warp_hd256.cu @@ -0,0 +1,2 @@ +#define HD_VAL 256 +#include "test_fmha_6warp.cu" diff --git a/tests/unit/test_fmha_6warp_hd64.cu b/tests/unit/test_fmha_6warp_hd64.cu new file mode 100644 index 00000000..a5847981 --- /dev/null +++ b/tests/unit/test_fmha_6warp_hd64.cu @@ -0,0 +1,2 @@ +#define HD_VAL 64 +#include "test_fmha_6warp.cu"