// diff --git a/dsv4/kernels/attention/fmha_6warp_tma.cuh b/dsv4/kernels/attention/fmha_6warp_tma.cuh new file mode 100644 index 00000000..b8b40618 --- /dev/null +++ b/dsv4/kernels/attention/fmha_6warp_tma.cuh @@ -0,0 +1,347 @@
/**
 * DSV4 FMHA — 6-warp specialized kernel, multi-row softmax, TMA async loads.
 *
 * ==================================================================
 * DESIGN
 * ==================================================================
 *
 * Same 6-warp design as fmha_6warp_multirow.cuh, but replaces scalar
 * GMEM reads in the load warp with TMA async bulk copies.
 *
 * 6-warp CTA: warps 0-3 = softmax, warp 4 = MMA, warp 5 = TMA load.
 * Grid: (1, n_h, batch) — each CTA processes one head of one batch item.
 *
 * TMA PIPELINE (single-stage, no overlap yet):
 *   For each K sub-tile (kt), first for Q and then for K:
 *     1. Lane 0 of the load warp arms the mbarrier (init + expect_tx with
 *        the tile byte count) and issues cp.async.bulk.tensor.2d
 *     2. The same lane polls the mbarrier until the bytes have landed
 *     3. Load warp transposes row-major SMEM → canonical K-major SMEM
 *     4. MMA warp runs tcgen05.mma as before
 *
 * KEY: Q is loaded per K-sub-tile, not once at the start.
 * TMA tiles are always (128, 16) BF16 = 4KB — same for Q, K, V.
 *
 * ==================================================================
 * PRECONDITIONS (as implied by the code below)
 * ==================================================================
 *
 *  - blockDim.x == 192 (six warps; warp roles are derived from tid / 32).
 *  - T <= 128 and s_k <= SK_TILE <= 128: a single 128-row Q tile and a
 *    single 128-row K tile are processed; there is no loop over KV tiles.
 *  - Dynamic shared memory must cover the carve-up in the "SMEM allocation"
 *    section (the unit test mirrors it in compute_smem_tma()).
 *  - The TMA coordinates do not depend on blockIdx, so the descriptors in
 *    `params` must already address the tensor of the head this CTA works on.
 *    NOTE(review): only `o` / `lse` are offset by blockIdx.y / blockIdx.z —
 *    confirm this matches how callers build descriptors for multi-head grids.
 * ==================================================================
 */

#pragma once

#include "fmha_common.cuh"
#include "fmha_umma_desc.cuh"
#include "fmha_tma.cuh"

namespace dsv4::kernels::attention {

/**
 * Kernel arguments. All strides are in elements.
 * NOTE(review): q / k / v and their strides are not read by the kernel in
 * this file — the inputs are fetched exclusively through the TMA descriptors.
 */
struct FmhaMultiRowTmaParams {
    const bf16_t* __restrict__ q;
    const bf16_t* __restrict__ k;
    const bf16_t* __restrict__ v;
    bf16_t* __restrict__ o;
    float* __restrict__ lse;
    int s_k, T;
    float scale;
    int head_dim;
    int q_head_stride, q_batch_stride;
    int k_head_stride, k_batch_stride;
    int v_head_stride, v_batch_stride;
    int o_head_stride, o_batch_stride;
    int lse_head_stride, lse_batch_stride;
    // TMA descriptors (device pointers to CUtensorMap in GMEM)
    CUtensorMap* __restrict__ tma_q;  // Q: (T, HD) — 2D BF16 with byte strides
    CUtensorMap* __restrict__ tma_k;  // K: (s_k, HD)
    CUtensorMap* __restrict__ tma_v;  // V: (HD, s_k)
};

namespace fmha_6warp_tma_detail {

/**
 * Order this thread's generic-proxy shared-memory writes before later
 * async-proxy accesses (TMA, tcgen05.mma operand reads).
 */
__device__ __forceinline__ void fence_generic_to_async_smem() {
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
}

/** Fence this thread's SMEM writes for the async proxy, then block-sync. */
__device__ __forceinline__ void sync_smem_for_async_proxy() {
    fence_generic_to_async_smem();
    __syncthreads();
}

/**
 * Fetch one TMA box into shared memory and block until it has landed.
 * Must be called by exactly ONE thread; all other threads synchronise with
 * it through a following __syncthreads().
 *
 * Sequence: (re)initialise the mbarrier with arrival count 1, make the
 * initialisation visible to the async proxy, post the single arrival
 * together with the expected byte count, issue the copy, then poll.
 * Without the expect_tx arrival the pending-arrival count never reaches
 * zero, so the phase can never complete and the wait only ends by timing out.
 *
 * NOTE(review): the barrier is re-initialised for every transfer without an
 * intervening mbarrier.inval; this keeps the phase parity at 0, which is what
 * tma_mbarrier_wait() polls for. A phase-tracking barrier would be cleaner.
 *
 * @param smem_tile   destination tile in shared memory
 * @param smem_mbar   mbarrier object in shared memory
 * @param desc        tensor map in global memory
 * @param coord_x     innermost (column) coordinate of the box
 * @param coord_y     outermost (row) coordinate of the box
 * @param tile_bytes  size of the full box in bytes
 */
__device__ __forceinline__ void tma_fetch_tile_blocking(
    bf16_t* smem_tile,
    uint64_t* smem_mbar,
    const CUtensorMap* desc,
    int coord_x,
    int coord_y,
    uint32_t tile_bytes
) {
    const uint32_t dst_addr  = (uint32_t)__cvta_generic_to_shared(smem_tile);
    const uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(smem_mbar);

    tma_mbarrier_init(mbar_addr, 1);
    fence_generic_to_async_smem();
    asm volatile("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;"
                 :: "r"(mbar_addr), "r"(tile_bytes)
                 : "memory");
    tma_load_2d(dst_addr, (uint64_t)desc, mbar_addr, coord_x, coord_y);
    tma_mbarrier_wait(mbar_addr);
}

}  // namespace fmha_6warp_tma_detail

/**
 * Single-tile FMHA forward: O = softmax(scale * Q K^T) V, plus optional LSE.
 *
 * NOTE(review): the template parameter list was lost in the pasted source;
 * <int HD, int SK_TILE> is reconstructed from the names the body uses —
 * confirm order against the call sites.
 *
 * @tparam HD       head dimension (multiple of 16, at most 256)
 * @tparam SK_TILE  number of key columns held in the S tile (at most 128)
 * @param  params   pointers, sizes and TMA descriptors, see FmhaMultiRowTmaParams
 */
template <int HD, int SK_TILE>
__global__ void __launch_bounds__(192)
fmha_6warp_tma_kernel(FmhaMultiRowTmaParams params) {
    namespace detail = fmha_6warp_tma_detail;

    static_assert(HD % MMA_K_BF16 == 0, "HD must be a multiple of the MMA K extent");
    static_assert(HD % 16 == 0, "HD must be a multiple of 16 (PV N sub-tile)");
    static_assert(HD <= 256, "HD exceeds the TMEM columns allocated below");
    static_assert(SK_TILE % MMA_K_BF16 == 0, "SK_TILE must be a multiple of the MMA K extent");
    static_assert(SK_TILE % 8 == 0, "SK_TILE must be a multiple of 8 (x8 TMEM reads)");
    static_assert(SK_TILE <= 128, "only one 128-row K tile is loaded");

    static constexpr int NKT_QK = HD / MMA_K_BF16;
    static constexpr int NKT_PV = SK_TILE / MMA_K_BF16;
    static constexpr int N_NSUB = HD / 16;
    static constexpr int TILE_SZ = 128 * MMA_K_BF16;
    static constexpr int TMEM_N = (HD <= 128) ? 128 : 256;
    static constexpr int MAX_ROWS = 128;
    static constexpr int CORES_MN = 128 / 8;
    static constexpr int NUM_READS = SK_TILE / 8;
    static constexpr int TMA_TILE_BF16 = 128 * MMA_K_BF16;
    // Byte counts posted to the mbarrier; they must equal the TMA box sizes.
    static constexpr uint32_t QK_TILE_BYTES = TMA_TILE_BF16 * sizeof(bf16_t);
    static constexpr uint32_t V_TILE_BYTES = 16 * 128 * sizeof(bf16_t);

    const int head_idx = blockIdx.y;
    const int batch_idx = blockIdx.z;
    const int tid = threadIdx.x;
    const int wid = tid / 32;
    const int lane = tid % 32;
    const bool is_softmax_warp = (wid < 4);
    const bool is_mma_warp = (wid == 4);
    const bool is_load_warp = (wid == 5);
    const bool is_tma_issuer = is_load_warp && (lane == 0);
    const int T = params.T;
    const int s_k = params.s_k;
    const float scale = params.scale;

    // 64-bit offsets: head/batch index times stride can exceed 2^31 elements.
    bf16_t* __restrict__ o_head = params.o
        + (int64_t)head_idx * params.o_head_stride
        + (int64_t)batch_idx * params.o_batch_stride;
    float* __restrict__ lse_head = params.lse
        ? params.lse + (int64_t)head_idx * params.lse_head_stride
                     + (int64_t)batch_idx * params.lse_batch_stride
        : nullptr;

    CUtensorMap* __restrict__ tma_q = params.tma_q;
    CUtensorMap* __restrict__ tma_k = params.tma_k;
    CUtensorMap* __restrict__ tma_v = params.tma_v;

    // ==================================================================
    // SMEM allocation
    // ==================================================================
    extern __shared__ char sbuf[];
    size_t off = 0;

    uint32_t* sTmemBase = (uint32_t*)sbuf; off = 4;

    off = (off + 127) & ~(size_t)127;
    uint64_t* sMbar = (uint64_t*)(sbuf + off); off += 8;

    float* sRowMax = (float*)(sbuf + off); off += MAX_ROWS * sizeof(float);
    float* sRowSum = (float*)(sbuf + off); off += MAX_ROWS * sizeof(float);

    off = (off + 127) & ~(size_t)127;
    bf16_t* sQ_tma = (bf16_t*)(sbuf + off); off += TMA_TILE_BF16 * sizeof(bf16_t);

    off = (off + 127) & ~(size_t)127;
    bf16_t* sK_tma = (bf16_t*)(sbuf + off); off += TMA_TILE_BF16 * sizeof(bf16_t);

    off = (off + 127) & ~(size_t)127;
    bf16_t* sQ = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t);

    off = (off + 127) & ~(size_t)127;
    bf16_t* sK = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t);

    off = (off + 127) & ~(size_t)127;
    bf16_t* sPk = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t);

    off = (off + 127) & ~(size_t)127;
    bf16_t* sV_tma = (bf16_t*)(sbuf + off); off += 16 * 128 * sizeof(bf16_t);

    off = (off + 127) & ~(size_t)127;
    bf16_t* sV = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t);

    // ==================================================================
    // Initialize
    // ==================================================================
    // The mbarrier is (re)initialised by the issuing lane right before each
    // transfer, so only the TMEM allocation happens here.
    if (is_mma_warp) {
        uint32_t smem_ptr = __cvta_generic_to_shared(sTmemBase);
        tmem_alloc(smem_ptr, TMEM_N);
    }
    __syncthreads();
    uint32_t tb = *sTmemBase;

    // One softmax thread per query row: row = wid * 32 + lane. For T <= 32
    // only warp 0 takes part.
    const bool my_warp_active = (T <= 32) ? (wid == 0) : is_softmax_warp;
    const int my_row = my_warp_active ? (wid * 32 + lane) : 0;
    const bool my_row_active = my_warp_active && (my_row < T);

    // ==================================================================
    // QK GEMM → S in TMEM (loop over K sub-tiles)
    // ==================================================================
    // NOTE(review): no tcgen05.commit / completion wait is visible between
    // issuing an MMA and overwriting its SMEM operands in the next
    // iteration — confirm umma_ss_f16() provides that ordering.
    for (int kt = 0; kt < NKT_QK; kt++) {
        // --- TMA load Q sub-tile --- coord {col_offset, row_offset} = {kt*16, 0}
        if (is_tma_issuer) {
            detail::tma_fetch_tile_blocking(sQ_tma, sMbar, tma_q,
                                            kt * MMA_K_BF16, 0, QK_TILE_BYTES);
        }
        __syncthreads();

        if (is_load_warp) write_smem_canonical<128, MMA_K_BF16>(sQ, sQ_tma);
        detail::sync_smem_for_async_proxy();

        // --- TMA load K sub-tile ---
        if (is_tma_issuer) {
            detail::tma_fetch_tile_blocking(sK_tma, sMbar, tma_k,
                                            kt * MMA_K_BF16, 0, QK_TILE_BYTES);
        }
        __syncthreads();

        if (is_load_warp) write_smem_canonical<128, MMA_K_BF16>(sK, sK_tma);
        detail::sync_smem_for_async_proxy();

        // MMA: sQ × sK → TMEM (accumulate from the second sub-tile on)
        if (is_mma_warp) {
            uint32_t idesc = make_idesc(128, 128);
            uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ), 128);
            uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK), 128);
            if (tid == 128) umma_ss_f16(tb, dq, dk, idesc, kt > 0);
            asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
        }
        __syncthreads();
    }

    asm volatile("fence.sc.gpu;" ::: "memory");
    __syncthreads();

    // ==================================================================
    // SOFTMAX (identical to non-TMA kernel)
    // ==================================================================
    // NOTE(review): every softmax warp passes the same TMEM address
    // (tb + n * 8) with no per-warp lane offset — confirm this selects the
    // intended 32-lane slice for warps 1-3.
    // Pass 1: row maximum of the scaled scores over the valid columns.
    float my_row_max = -INFINITY;
    if (my_warp_active) {
        for (int n = 0; n < NUM_READS; n++) {
            float tmp[8];
            asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
                : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
                  "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
                : "r"(tb + n * 8));
            asm volatile("tcgen05.wait::ld.sync.aligned;");
            if (my_row_active) {
                for (int c = 0; c < 8; c++) {
                    int col = n * 8 + c;
                    if (col < s_k) my_row_max = fmaxf(my_row_max, tmp[c] * scale);
                }
            }
        }
    }
    if (my_row_active) sRowMax[my_row] = my_row_max;
    __syncthreads();

    // Pass 2: unnormalised probabilities and their row sum. Only entries
    // with col < s_k are written; the PV stage must not read the others.
    float my_p_vals[SK_TILE];
    float my_row_sum = 0.0f;
    if (my_warp_active) {
        float rm = my_row_active ? sRowMax[my_row] : 0.0f;
        for (int n = 0; n < NUM_READS; n++) {
            float tmp[8];
            asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
                : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
                  "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
                : "r"(tb + n * 8));
            asm volatile("tcgen05.wait::ld.sync.aligned;");
            if (my_row_active) {
                for (int c = 0; c < 8; c++) {
                    int col = n * 8 + c;
                    if (col < s_k) {
                        float p = expf(tmp[c] * scale - rm);
                        my_p_vals[col] = p;
                        my_row_sum += p;
                    }
                }
            }
        }
    }
    if (my_row_active) sRowSum[my_row] = my_row_sum;
    __syncthreads();

    // ==================================================================
    // PV GEMM
    // ==================================================================
    // The O accumulator for N sub-tile n_sub is written to TMEM columns
    // [tb + n_sub*16, +16); S is no longer needed because P lives in registers.
    for (int n_sub = 0; n_sub < N_NSUB; n_sub++) {
        int d_base = n_sub * 16;
        for (int pv_kt = 0; pv_kt < NKT_PV; pv_kt++) {
            const int col_start = pv_kt * MMA_K_BF16;

            // Clear the P tile so rows >= T contribute nothing.
            if (is_load_warp) for (int i = lane; i < TILE_SZ; i += 32) sPk[i] = 0;
            __syncthreads();

            // Each active row writes its 16-column slice of P in canonical layout.
            if (my_row_active) {
                const int core_mn = my_row / 8, local_r = my_row % 8;
                for (int c = 0; c < MMA_K_BF16; c++) {
                    const int gc = col_start + c;
                    const int ck = c / 8, lc = c % 8;
                    // Columns past s_k were never computed: store an explicit
                    // zero instead of an uninitialised register value.
                    const float p = (gc < s_k) ? my_p_vals[gc] : 0.0f;
                    sPk[ck*CORES_MN*64 + core_mn*64 + local_r*8 + lc] = f32_to_bf16(p);
                }
            }
            detail::sync_smem_for_async_proxy();

            // --- TMA load V sub-tile ---
            // V is (HD, s_k). TMA 2D: coord {col_start, d_base}
            // Descriptor: (s_k, HD) innermost-first. Loading sub-tile at (col_start, d_base)
            if (is_tma_issuer) {
                detail::tma_fetch_tile_blocking(sV_tma, sMbar, tma_v,
                                                col_start, d_base, V_TILE_BYTES);
            }
            __syncthreads();

            // Transpose sV_tma (16, 128) → sV (128, 16) canonical
            if (is_load_warp) {
                constexpr int SV_CORES_MN = 128 / 8;
                for (int i = lane; i < TILE_SZ; i += 32) sV[i] = 0;
                for (int i = lane; i < 16 * 128; i += 32) {
                    int d = i / 128, r = i % 128;
                    int core_mn = r / 8, local_r = r % 8;
                    int core_k = d / 8, local_c = d % 8;
                    int dst_idx = core_k * SV_CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c;
                    sV[dst_idx] = sV_tma[i];
                }
            }
            detail::sync_smem_for_async_proxy();

            if (is_mma_warp) {
                uint32_t idesc_pv = make_idesc(128, 16);
                uint64_t dp = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sPk), 128);
                uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
                if (tid == 128) umma_ss_f16(tb + n_sub*16, dp, dv, idesc_pv, pv_kt > 0);
                asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
            }
            __syncthreads();
        }
    }

    asm volatile("fence.sc.gpu;" ::: "memory");
    __syncthreads();

    // ==================================================================
    // EPILOGUE (identical to non-TMA kernel)
    // ==================================================================
    // Normalise each row by its softmax sum and store O (row-major, HD wide)
    // plus LSE = log(sum) + max.
    if (my_warp_active) {
        float rm = my_row_active ? sRowMax[my_row] : 0.0f;
        float rs = my_row_active ? sRowSum[my_row] : 0.0f;
        float inv_rs = my_row_active ? (1.0f / rs) : 0.0f;

        for (int n = 0; n < N_NSUB * 2; n++) {
            float tmp[8];
            asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
                : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
                  "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
                : "r"(tb + n * 8));
            asm volatile("tcgen05.wait::ld.sync.aligned;");
            if (my_row_active) {
                for (int c = 0; c < 8; c++) {
                    int d = n * 8 + c;
                    if (d < HD) o_head[my_row * HD + d] = f32_to_bf16(tmp[c] * inv_rs);
                }
            }
        }
        if (my_row_active && lse_head) lse_head[my_row] = logf(rs) + rm;
    }
    __syncthreads();
    if (is_mma_warp) tmem_dealloc(tb, TMEM_N);
}

}  // namespace dsv4::kernels::attention
// diff --git a/dsv4/kernels/attention/fmha_tma.cuh b/dsv4/kernels/attention/fmha_tma.cuh new file mode 100644 index 00000000..8025fc5d --- /dev/null +++ b/dsv4/kernels/attention/fmha_tma.cuh @@ -0,0 +1,167 @@
/**
 * DSV4 FMHA — TMA async load infrastructure for Blackwell SM100.
 *
 * ==================================================================
 * CUDA 13 CRITICAL NOTES (verified on B200, driver 580.126.20)
 * ==================================================================
 *
 * 1. globalStrides are in BYTES, not elements.
 *    Old (CUDA 12): gs[] = {1, cols} — element strides
 *    New (CUDA 13): gs[] = {cols*2, ...} — byte strides
 *    This was the root cause of cuTensorMapEncodeTiled returning
 *    INVALID_VALUE for 2D+ descriptors.
 *
 * 2. Use BFLOAT16 data type (CU_TENSOR_MAP_DATA_TYPE_BFLOAT16)
 *    instead of UINT16 for BF16 tensors.
 *
 * 3. mbarrier wait: the @p bra DONE pattern HANGS on SM100.
 *    Use the selp.b32 polling pattern instead. See tma_mbarrier_wait.
 *
 * 4. CUTLASS driver workaround: for driver <= 13.1 and total_bytes
 *    < 128KB, clear bit 21 of descriptor word[1]. See
 *    docs/cuda13_tma_notes.md for details.
+ * ================================================================== + */ + +#pragma once + +#include "fmha_common.cuh" +#include + +namespace dsv4::kernels::attention { + +// ================================================================== +// TMA descriptor helpers (host-side) +// ================================================================== + +/** + * Create a 2D TMA descriptor for a BF16 tensor of shape (rows, cols). + * + * IMPORTANT: CUDA 13 requires byte strides, not element strides. + * globalStrides[0] = cols * 2 (bytes, BF16 = 2 bytes) + * For rank=2, only 1 stride is needed (rank-1). + */ +inline bool create_tma_desc_2d_bf16( + CUtensorMap* out, + const void* gmem_ptr, + uint64_t rows, + uint64_t cols, + uint32_t tile_rows, + uint32_t tile_cols, + CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_NONE +) { + // 2D: (cols, rows) innermost-first + // Byte strides: stride[0] = cols * sizeof(bf16) = cols * 2 + uint64_t global_dim[] = {cols, rows}; + uint64_t global_str[] = {cols * 2}; // byte stride (CUDA 13) + uint32_t tile_dim[] = {tile_cols, tile_rows}; + uint32_t tile_str[] = {1, 1}; // element strides within tile + + CUresult res = cuTensorMapEncodeTiled( + out, + CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, + 2, + const_cast(gmem_ptr), + global_dim, global_str, tile_dim, tile_str, + CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzle, + CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE + ); + if (res != CUDA_SUCCESS) { + fprintf(stderr, "cuTensorMapEncodeTiled failed: error=%d, gdim=[%lu,%lu], gstr=[%lu], tdim=[%u,%u]\n", + (int)res, global_dim[0], global_dim[1], global_str[0], + tile_dim[0], tile_dim[1]); + return false; + } + + // CUTLASS driver workaround: clear bit 21 of desc[1] for driver <= 13.1 + small tensors + int driver_version = 0; + cudaDriverGetVersion(&driver_version); + size_t total_bytes = rows * cols * 2; // BF16 + if (driver_version <= 13010 && total_bytes < 131072) { + reinterpret_cast(out)[1] &= ~(1ULL << 21); + } + + return 
true; +} + +// ================================================================== +// TMA kernel-side operations +// ================================================================== + +/** + * Initialize an mbarrier in SMEM with expected transaction count. + * For TMA with complete_tx::bytes, pass the byte count of the transfer. + * For simplicity, can also use count=1. + */ +__device__ __forceinline__ void tma_mbarrier_init(uint32_t smem_mbar, uint32_t expected_count = 1) { + asm volatile("mbarrier.init.shared.b64 [%0], %1;" + :: "r"(smem_mbar), "r"(expected_count)); +} + +/** + * Issue a 2D TMA async copy from GMEM to SMEM. + * + * @param smem_dst SMEM destination address (via __cvta_generic_to_shared) + * @param tma_desc Pointer to CUtensorMap in device memory (uint64_t cast) + * @param smem_mbar SMEM mbarrier address (via __cvta_generic_to_shared) + * @param coord_x Column coordinate (innermost dimension) + * @param coord_y Row coordinate (outermost dimension) + */ +__device__ __forceinline__ void tma_load_2d( + uint32_t smem_dst, + uint64_t tma_desc, + uint32_t smem_mbar, + int coord_x, + int coord_y +) { + asm volatile( + "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes " + "[%0], [%1, {%3, %4}], [%2];" + :: "r"(smem_dst), + "l"(tma_desc), + "r"(smem_mbar), + "r"(coord_x), + "r"(coord_y) + : "memory" + ); +} + +/** + * Wait for mbarrier completion (polling approach). + * + * CRITICAL: The @p bra DONE pattern HANGS on SM100! + * Use the selp.b32 approach to convert the predicate to an integer, + * then branch on the integer. This is the ONLY pattern that works + * reliably with mbarrier.try_wait.parity on Blackwell SM100. 
+ */ +__device__ __forceinline__ void tma_mbarrier_wait(uint32_t smem_mbar) { + int phase = 0; + int done = 0; + for (int i = 0; i < 10000000; i++) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "mbarrier.try_wait.parity.shared.b64 p, [%0], %1;\n\t" + "selp.b32 %2, 1, 0, p;\n\t" + "}" + : "=r"(done) + : "r"(smem_mbar), "r"(phase) + : "memory" + ); + if (done) break; + } +} + +// ================================================================== +// TMA parameter structure +// ================================================================== + +struct FmhaTmaDescriptors { + CUtensorMap* __restrict__ tma_q; // Q descriptor: (T, HD) + CUtensorMap* __restrict__ tma_k; // K descriptor: (s_k, HD) + CUtensorMap* __restrict__ tma_v; // V descriptor: (HD, s_k) +}; + +} // namespace dsv4::kernels::attention diff --git a/tests/unit/test_fmha_tma.cu b/tests/unit/test_fmha_tma.cu new file mode 100644 index 00000000..3faab79a --- /dev/null +++ b/tests/unit/test_fmha_tma.cu @@ -0,0 +1,227 @@ +/** + * Test TMA async FMHA kernel (6-warp, multi-row, TMA loads). + * Compile with -DHD_VAL=64 etc. + * + * Uses CUDA 13 TMA descriptors with byte strides and BFLOAT16 data type. + * mbarrier wait uses selp.b32 polling (@p bra HANGS on SM100). 
+ */ + +#include +#include +#include +#include +#include +#include + +#ifndef HD_VAL +#define HD_VAL 64 +#endif + +#include "dsv4/kernels/attention/fmha_common.cuh" +#include "dsv4/kernels/attention/fmha_umma_desc.cuh" +#include "dsv4/kernels/attention/fmha_tma.cuh" + +using namespace dsv4::kernels::attention; + +static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); } +static float bf16_to_f32_host(bf16_t h) { uint32_t u=(uint32_t)h<<16; float f; memcpy(&f,&u,4); return f; } + +constexpr int HD = HD_VAL; +constexpr int SK = 128; +constexpr int MAX_T = 128; + +#include "dsv4/kernels/attention/fmha_6warp_tma.cuh" + +static int compute_smem_tma() { + size_t off = 0; + off += 4; // sTmemBase + off = (off + 127) & ~(size_t)127; + off += 8; // sMbar + off += MAX_T * sizeof(float); // sRowMax + off += MAX_T * sizeof(float); // sRowSum + off = (off + 127) & ~(size_t)127; + off += 128 * MMA_K_BF16 * sizeof(bf16_t); // sQ_tma + off = (off + 127) & ~(size_t)127; + off += 128 * MMA_K_BF16 * sizeof(bf16_t); // sK_tma + off = (off + 127) & ~(size_t)127; + off += 128 * MMA_K_BF16 * sizeof(bf16_t); // sQ canonical + off = (off + 127) & ~(size_t)127; + off += 128 * MMA_K_BF16 * sizeof(bf16_t); // sK canonical + off = (off + 127) & ~(size_t)127; + off += 128 * MMA_K_BF16 * sizeof(bf16_t); // sPk canonical + off = (off + 127) & ~(size_t)127; + off += 16 * 128 * sizeof(bf16_t); // sV_tma + off = (off + 127) & ~(size_t)127; + off += 128 * MMA_K_BF16 * sizeof(bf16_t); // sV canonical + return (int)off; +} + +static void reference_attention_multirow( + const bf16_t* q, const bf16_t* k, const bf16_t* v, + float* o_ref, float* lse_ref, + int hd, int T, int s_k, float scale +) { + for (int t = 0; t < T; t++) { + float s[512]; + for (int j = 0; j < s_k; j++) { + float dot = 0.0f; + for (int d = 0; d < hd; d++) + dot += bf16_to_f32_host(q[t * hd + d]) * bf16_to_f32_host(k[j * hd + d]); + s[j] = dot * scale; + } + float mx = -INFINITY; + for (int j = 
0; j < s_k; j++) mx = fmaxf(mx, s[j]); + float sm = 0.0f; + for (int j = 0; j < s_k; j++) { s[j] = expf(s[j] - mx); sm += s[j]; } + for (int j = 0; j < s_k; j++) s[j] /= sm; + for (int d = 0; d < hd; d++) { + float ov = 0.0f; + for (int j = 0; j < s_k; j++) ov += s[j] * bf16_to_f32_host(v[d * s_k + j]); + o_ref[t * hd + d] = ov; + } + if (lse_ref) lse_ref[t] = logf(sm) + mx; + } +} + +struct TmaDescSet { + CUtensorMap tma_q, tma_k, tma_v; + CUtensorMap *d_tma_q, *d_tma_k, *d_tma_v; + + bool create(bf16_t* d_q, bf16_t* d_k, bf16_t* d_v, + int T, int hd, int s_k) { + // Q: (128, HD) padded, TMA tile = (128, 16) + if (!create_tma_desc_2d_bf16(&tma_q, d_q, 128, (uint64_t)hd, 128, 16)) { + printf(" Q TMA desc FAILED\n"); return false; + } + // K: (s_k, HD), TMA tile = (128, 16) + if (!create_tma_desc_2d_bf16(&tma_k, d_k, (uint64_t)s_k, (uint64_t)hd, 128, 16)) { + printf(" K TMA desc FAILED\n"); return false; + } + // V: (HD, s_k), TMA tile = (16, 128) + // V innermost dim = s_k, tile = (128, 16) means tile_cols=128, tile_rows=16 + if (!create_tma_desc_2d_bf16(&tma_v, d_v, (uint64_t)hd, (uint64_t)s_k, 16, 128)) { + printf(" V TMA desc FAILED\n"); return false; + } + + cudaMalloc(&d_tma_q, sizeof(CUtensorMap)); + cudaMalloc(&d_tma_k, sizeof(CUtensorMap)); + cudaMalloc(&d_tma_v, sizeof(CUtensorMap)); + cudaMemcpy(d_tma_q, &tma_q, sizeof(CUtensorMap), cudaMemcpyHostToDevice); + cudaMemcpy(d_tma_k, &tma_k, sizeof(CUtensorMap), cudaMemcpyHostToDevice); + cudaMemcpy(d_tma_v, &tma_v, sizeof(CUtensorMap), cudaMemcpyHostToDevice); + return true; + } + + void destroy() { + if (d_tma_q) { cudaFree(d_tma_q); d_tma_q = nullptr; } + if (d_tma_k) { cudaFree(d_tma_k); d_tma_k = nullptr; } + if (d_tma_v) { cudaFree(d_tma_v); d_tma_v = nullptr; } + } +}; + +static int test_single(int T, int n_h = 1, int batch = 1) { + printf("\n=== TMA T=%d, n_h=%d, batch=%d, HD=%d ===\n", T, n_h, batch, HD); + const float SCALE = 1.0f / sqrtf((float)HD); + int total_heads = batch * n_h; + constexpr int 
Q_PAD_ROWS = 128; + + bf16_t* h_q = (bf16_t*)calloc(total_heads * Q_PAD_ROWS * HD, sizeof(bf16_t)); + bf16_t* h_k = (bf16_t*)malloc(total_heads * SK * HD * sizeof(bf16_t)); + bf16_t* h_v = (bf16_t*)malloc(total_heads * HD * SK * sizeof(bf16_t)); + bf16_t* h_o = (bf16_t*)calloc(total_heads * MAX_T * HD, sizeof(bf16_t)); + float* h_lse = (float*)calloc(total_heads * MAX_T, sizeof(float)); + + srand(42 + T); + for (int i = 0; i < total_heads * T * HD; i++) h_q[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f); + for (int i = 0; i < total_heads * SK * HD; i++) h_k[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f); + for (int i = 0; i < total_heads * HD * SK; i++) h_v[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f); + + bf16_t *d_q, *d_k, *d_v, *d_o; float *d_lse; + cudaMalloc(&d_q, total_heads * Q_PAD_ROWS * HD * sizeof(bf16_t)); + cudaMalloc(&d_k, total_heads * SK * HD * sizeof(bf16_t)); + cudaMalloc(&d_v, total_heads * HD * SK * sizeof(bf16_t)); + cudaMalloc(&d_o, total_heads * MAX_T * HD * sizeof(bf16_t)); + cudaMalloc(&d_lse, total_heads * MAX_T * sizeof(float)); + cudaMemcpy(d_q, h_q, total_heads * Q_PAD_ROWS * HD * sizeof(bf16_t), cudaMemcpyHostToDevice); + cudaMemcpy(d_k, h_k, total_heads * SK * HD * sizeof(bf16_t), cudaMemcpyHostToDevice); + cudaMemcpy(d_v, h_v, total_heads * HD * SK * sizeof(bf16_t), cudaMemcpyHostToDevice); + + int failed = 0; + float min_cos = 1.0f; + + for (int b = 0; b < batch; b++) { + for (int h = 0; h < n_h; h++) { + int idx = b * n_h + h; + TmaDescSet tma; + bf16_t* d_q_h = d_q + idx * Q_PAD_ROWS * HD; + bf16_t* d_k_h = d_k + idx * SK * HD; + bf16_t* d_v_h = d_v + idx * HD * SK; + if (!tma.create(d_q_h, d_k_h, d_v_h, T, HD, SK)) { + failed++; continue; + } + + FmhaMultiRowTmaParams params; + params.q = d_q_h; params.k = d_k_h; params.v = d_v_h; + params.o = d_o + idx * MAX_T * HD; params.lse = d_lse + idx * MAX_T; + params.s_k = SK; params.T = T; params.scale = SCALE; params.head_dim = HD; + params.q_head_stride 
= Q_PAD_ROWS * HD; params.q_batch_stride = n_h * Q_PAD_ROWS * HD; + params.k_head_stride = SK * HD; params.k_batch_stride = n_h * SK * HD; + params.v_head_stride = HD * SK; params.v_batch_stride = n_h * HD * SK; + params.o_head_stride = MAX_T * HD; params.o_batch_stride = n_h * MAX_T * HD; + params.lse_head_stride = MAX_T; params.lse_batch_stride = n_h * MAX_T; + params.tma_q = tma.d_tma_q; params.tma_k = tma.d_tma_k; params.tma_v = tma.d_tma_v; + + int smem = compute_smem_tma(); + if (smem > 48 * 1024) + cudaFuncSetAttribute(fmha_6warp_tma_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); + + fmha_6warp_tma_kernel<<>>(params); + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + printf(" CUDA ERROR b=%d h=%d: %s\n", b, h, cudaGetErrorString(err)); + failed++; tma.destroy(); continue; + } + + bf16_t* h_o_head = (bf16_t*)malloc(T * HD * sizeof(bf16_t)); + float* h_lse_head = (float*)malloc(T * sizeof(float)); + cudaMemcpy(h_o_head, d_o + idx * MAX_T * HD, T * HD * sizeof(bf16_t), cudaMemcpyDeviceToHost); + cudaMemcpy(h_lse_head, d_lse + idx * MAX_T, T * sizeof(float), cudaMemcpyDeviceToHost); + + float o_ref[MAX_T * 512]; float lse_ref[MAX_T]; + reference_attention_multirow(h_q + idx * Q_PAD_ROWS * HD, h_k + idx * SK * HD, h_v + idx * HD * SK, o_ref, lse_ref, HD, T, SK, SCALE); + + for (int t = 0; t < T; t++) { + float cs=0,na=0,nb=0; + for (int d=0;d1e-4f){cs+=a*b2;na+=a*a;nb+=b2*b2;} + } + cs /= (sqrtf(na)*sqrtf(nb)+1e-10f); + if(cs