From 414b3f4f92cf24d28041f5f1f3c38eeae26688ef Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 14:13:11 +0000 Subject: [PATCH] Full FMHA HD=16 with PV GEMM via tcgen05.mma TS --- tests/unit/test_fmha_ts_full.cu | 244 ++++++++++++++++++++++++++++++++ 1 file changed, 244 insertions(+) create mode 100644 tests/unit/test_fmha_ts_full.cu diff --git a/tests/unit/test_fmha_ts_full.cu b/tests/unit/test_fmha_ts_full.cu new file mode 100644 index 00000000..bd2530bc --- /dev/null +++ b/tests/unit/test_fmha_ts_full.cu @@ -0,0 +1,244 @@ +/** + * Full FMHA HD=16, SK=128 — with PV GEMM via tcgen05.mma TS + * + * Pipeline: Q×K^T (SS) → softmax (TMEM) → P×V (TS) → epilogue + * + * Key design: + * - TMEM: 256 columns. Cols 0-127 = S/P. Cols 128-143 = O. + * - QK: SS MMA (1 K-tile at HD=16) + * - Softmax: read S from TMEM, compute P, write P back to TMEM + * - PV: TS MMA, 8 K-tiles (SK=128 → SK/16 = 8 K-tiles) + * - Epilogue: read O from TMEM, apply MMA scale (2×), write to GMEM + */ + +#include +#include +#include +#include +#include + +#include "dsv4/kernels/attention/fmha_common.cuh" +#include "dsv4/kernels/attention/fmha_umma_desc.cuh" + +using namespace dsv4::kernels::attention; + +static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); } +static float bf16_to_f32_host(bf16_t h) { uint32_t u=(uint32_t)h<<16; float f; memcpy(&f,&u,4); return f; } + +constexpr int HD = 16, SK = 128, BLOCK_MN = 128; +constexpr int NKT_QK = HD / MMA_K_BF16; // 1 K-tile for QK at HD=16 +constexpr int NKT_PV = SK / MMA_K_BF16; // 8 K-tiles for PV at SK=128 +constexpr int TMEM_P = 128; +constexpr int TMEM_O = 32; // 16 cols → round to 32 (power of 2, min 32) +constexpr int TMEM_N = 256; +constexpr int TILE_SZ = BLOCK_MN * MMA_K_BF16; // 2048 BF16 per K-tile + +__global__ void __launch_bounds__(128) +test_fmha_ts_full(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k, + const bf16_t* __restrict__ v, bf16_t* __restrict__ o_out, + float* __restrict__ o_scalar, float scale) +{ + const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32; + + // SMEM: tmem_base + sQ(128,16) + sK(128,16) + sV(8 × 16×16) + extern __shared__ char sbuf[]; + uint32_t* sTmemBase = (uint32_t*)sbuf; + bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15); + bf16_t* sK0 = sQ0 + TILE_SZ; // 2048 BF16 per K-tile + bf16_t* sV = (bf16_t*)(((uintptr_t)(sK0 + TILE_SZ) + 127) & ~(uintptr_t)127); + + // Load Q: (1, HD) → (128, 16) canonical + write_q_to_smem(sQ0, q); + // Load K: (SK, HD) → (128, 16) canonical + write_k_to_smem(sK0, k); + + // Load V: (HD, SK) → 8 K-tiles of (16, 16) canonical + // For PV: B = (SK, HD) = (128, 16). The kt-th K-tile: B[16*kt..16*kt+16, 0..16) + // B[r, d] = V[d, r] = V[d * SK + r] + for (int kt = 0; kt < NKT_PV; kt++) { + bf16_t* sv = sV + kt * 256; // 256 BF16 = 512 bytes per (16,16) tile + for (int i = tid; i < 256; i += 128) sv[i] = 0; + for (int d = tid; d < HD; d += 128) { + for (int lr = 0; lr < MMA_K_BF16; lr++) { + int r = kt * MMA_K_BF16 + lr; + int ck = d / 8, lc = d % 8; + int tmn = lr / 8, llr = lr % 8; + int dst_idx = ck * 2 * 64 + tmn * 64 + llr * 8 + lc; + sv[dst_idx] = v[d * SK + r]; + } + } + } + __syncthreads(); + + // TMEM alloc + if (wid == 1) tmem_alloc(__cvta_generic_to_shared(sTmemBase), TMEM_N); + __syncthreads(); + uint32_t tb = *sTmemBase; + uint32_t tb_o = tb + TMEM_P; // O at column 128 + + // ===== STEP 1: QK GEMM (SS) ===== + { + uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), BLOCK_MN); + uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK0), BLOCK_MN); + uint32_t idesc_qk = make_idesc(BLOCK_MN, BLOCK_MN); + for (int kt = 0; kt < NKT_QK; kt++) { + if (tid == 0) umma_ss_f16(tb, dq, dk, idesc_qk, kt > 0); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + __syncthreads(); + } + } + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + __syncthreads(); + + // ===== STEP 2: Softmax ===== + if (wid == 0) { + float s_vals[SK], row_max = -INFINITY; + for (int n = 0; n < SK / 8; n++) { + float tmp[8]; + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" + : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), + "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) + : "r"(tb + n*8)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + if (lane == 0) for (int c=0;c<8;c++) { + s_vals[n*8+c] = tmp[c] * scale; + row_max = fmaxf(row_max, tmp[c] * scale); + } + } + row_max = wmax(row_max); + float row_sum = 0.0f; + if (lane == 0) for (int j=0;j 0); + if (tid == 0) umma_ts_f16(tb_o, tmem_a, dv, idesc_pv, accumulate); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + __syncthreads(); + } + } + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + __syncthreads(); + + // ===== STEP 4: Epilogue — read O from TMEM ===== + // MMA output is scaled by 0.5, so multiply by 2.0 + if (wid == 0) { + float o_vals[HD]; + for (int n = 0; n < HD / 8; n++) { + float tmp[8]; + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" + : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), + "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) + : "r"(tb_o + n*8)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c] * 2.0f; + } + if (lane == 0) for (int d=0;d>>(d_q, d_k, d_v, d_o, d_o_scalar, SCALE); + + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; } + + cudaMemcpy(h_o, d_o, HD*sizeof(bf16_t), cudaMemcpyDeviceToHost); + cudaMemcpy(h_o_scalar, d_o_scalar, HD*sizeof(float), cudaMemcpyDeviceToHost); + + printf("O[0..15] MMA: "); for(int d=0;d0 ? max_diff/max_val : max_diff; + float cos_sim=0,na=0,nb=0; + for (int d=0;d 0.999f ? "PASSED" : "FAILED"); + + cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o); cudaFree(d_o_scalar); + free(h_q); free(h_k); free(h_v); free(h_o); free(h_o_scalar); + return cos_sim > 0.999f ? 0 : 1; +}