diff --git a/PREVIOUS_CONVERSATION_2.md b/PREVIOUS_CONVERSATION_2.md index ac43f354..34a0b79c 100644 --- a/PREVIOUS_CONVERSATION_2.md +++ b/PREVIOUS_CONVERSATION_2.md @@ -626,4 +626,28 @@ What I recommend: **focus on what works.** The raw CUDA 6-warp kernel is fully w Should I: - **A)** Continue debugging the CuTeDSL pv_n_tile=16 crash (might take a while, but gives us TMA + correct PV) - **B)** Move on to multi-head launch (more immediate production impact, no TMA needed) -- **C)** Try to find the correct `cuTensorMapEncodeTiled` parameters by studying the CuTeDSL source code that creates the descriptors \ No newline at end of file +- **C)** Try to find the correct `cuTensorMapEncodeTiled` parameters by studying the CuTeDSL source code that creates the descriptors + + +Milestone 5 is done. Here's the summary: + +**Multi-head FMHA grid launch — ALL PASS ✅** + +- `fmha_6warp_multihead.cuh`: kernel with `FmhaParams` struct, grid=`(1, n_h, batch)` +- Each CTA processes one head independently — no cross-CTA sync +- **MQA**: `k_head_stride=0` / `v_head_stride=0` lets all Q heads share KV +- **LSE output**: `lse = ln(row_sum) + row_max` for D5 multi-segment KV merge +- Tested across **all 4 HD values × 3 modes** (MHA 4+8 heads, MQA, batched 4×2): + - HD=16: cos 0.999995–0.999998 ✅ + - HD=64: cos 0.999996–0.999998 ✅ + - HD=128: cos 0.999996–0.999998 ✅ + - HD=256: cos 0.999996–0.999998 ✅ + - LSE matches reference to 6 decimal places across all configs + +**Remaining milestones:** +- **M2 TMA loads**: blocked on B200 driver (cuTensorMapEncodeTiled 2D+ fails) +- **M3 Pipeline overlap**: depends on TMA +- **M4 Multi-row softmax**: for prefill T>1 — this is unblocked and next in line +- **M6 Production integration**: hook into production.py + +Want me to tackle multi-row softmax (M4) next? That's the big unlock for prefill and it doesn't depend on TMA. \ No newline at end of file diff --git a/dsv4/kernels/attention/fmha_6warp_multirow.cuh b/dsv4/kernels/attention/fmha_6warp_multirow.cuh new file mode 100644 index 00000000..d4bf44a3 --- /dev/null +++ b/dsv4/kernels/attention/fmha_6warp_multirow.cuh @@ -0,0 +1,351 @@ +/** + * DSV4 FMHA — 6-warp specialized kernel, multi-row softmax (prefill T>1). + * + * ================================================================== + * MULTI-ROW SOFTMAX (Milestone 4) + * ================================================================== + * Extends the 6-warp kernel to handle T>1 rows (prefill). + * + * Architecture: + * Warp 0-3 (softmax/epilogue): Each warp handles up to 32 rows. + * Warp w processes rows [w*32, min(w*32+32, T)). + * For T<=32, only warp 0 is active. + * Warp 4 (MMA): QK + PV GEMM via tcgen05.mma SS + * Warp 5 (data staging): Load Q/K/V from GMEM to SMEM + * + * TMEM multi-row addressing: + * tcgen05.ld.32x32b.x8 reads 32 rows from TMEM. + * Address: tmem_base + (row_offset << 16) + col + * Warp w uses row_offset = w*32 to read its group of 32 rows. + * Lane l in warp w reads row (w*32 + l) from TMEM. + * + * Q layout for prefill: + * Q is (T, hd) per head, loaded as (128, hd) in SMEM canonical layout. + * Rows 0..T-1 have data; rows T..127 are zero. + * + * P layout for PV GEMM: + * After per-row softmax, P values are stored in s_p_vals[T][SK_TILE]. + * For PV, P is loaded as (128, 16) canonical sub-tiles. + * + * Output: + * O: (T, hd) per head, normalized. + * LSE: (T,) per head — one float per row. + * + * ================================================================== + * CONSTRAINTS + * ================================================================== + * - T <= 128 (fits in one 128-row MMA tile; longer → D5 KV merge) + * - s_k must be a multiple of 16 (MMA K-tile size) + * - head_dim must be one of: 16, 64, 128, 256 + * ================================================================== + */ + +#pragma once + +#include "fmha_common.cuh" +#include "fmha_umma_desc.cuh" + +namespace dsv4::kernels::attention { + +struct FmhaMultiRowParams { + const bf16_t* __restrict__ q; // [batch, n_h, T, hd] + const bf16_t* __restrict__ k; // [batch, n_kv, N, hd] + const bf16_t* __restrict__ v; // [batch, n_kv, hd, N] + bf16_t* __restrict__ o; // [batch, n_h, T, hd] + float* __restrict__ lse; // [batch, n_h, T] + + int s_k; // KV sequence length + int T; // Query sequence length (1 for decode, >1 for prefill) + float scale; // 1/sqrt(hd) + int head_dim; // hd + + int q_head_stride; // T * hd + int q_batch_stride; // n_h * T * hd + int k_head_stride; // N * hd (0 for MQA) + int k_batch_stride; // n_kv * N * hd + int v_head_stride; // hd * N (0 for MQA) + int v_batch_stride; // n_kv * hd * N + int o_head_stride; // T * hd + int o_batch_stride; // n_h * T * hd + int lse_head_stride; // T + int lse_batch_stride; // n_h * T +}; + +template +__global__ void __launch_bounds__(192) +fmha_6warp_multirow_kernel(FmhaMultiRowParams params) { + static constexpr int NKT_QK = HD / MMA_K_BF16; + static constexpr int NKT_PV = SK_TILE / MMA_K_BF16; + static constexpr int N_NSUB = HD / 16; + static constexpr int TILE_SZ = 128 * MMA_K_BF16; + static constexpr int V_SUB_SZ = 256; + static constexpr int TMEM_N = (HD <= 128) ? 128 : 256; + static constexpr int ROWS_PER_WARP = 32; + static constexpr int MAX_ROWS = 128; + + const int head_idx = blockIdx.y; + const int batch_idx = blockIdx.z; + const int tid = threadIdx.x; + const int wid = tid / 32; + const int lane = tid % 32; + + const bool is_softmax_warp = (wid < 4); + const bool is_mma_warp = (wid == 4); + const bool is_load_warp = (wid == 5); + + const int T = params.T; + const int s_k = params.s_k; + const float scale = params.scale; + + // Per-warp row range + const int warp_row_start = wid * ROWS_PER_WARP; + const int warp_row_end = min(warp_row_start + ROWS_PER_WARP, T); + const int warp_num_rows = is_softmax_warp ? max(0, warp_row_end - warp_row_start) : 0; + const bool warp_has_rows = (warp_num_rows > 0); + + // ================================================================== + // Per-head GMEM pointers + // ================================================================== + const bf16_t* __restrict__ q_head = params.q + + head_idx * params.q_head_stride + + batch_idx * params.q_batch_stride; + const bf16_t* __restrict__ k_head = params.k + + head_idx * params.k_head_stride + + batch_idx * params.k_batch_stride; + const bf16_t* __restrict__ v_head = params.v + + head_idx * params.v_head_stride + + batch_idx * params.v_batch_stride; + bf16_t* __restrict__ o_head = params.o + + head_idx * params.o_head_stride + + batch_idx * params.o_batch_stride; + float* __restrict__ lse_head = params.lse + ? params.lse + head_idx * params.lse_head_stride + + batch_idx * params.lse_batch_stride + : nullptr; + + // ================================================================ + // SMEM allocation + // ================================================================ + extern __shared__ char sbuf[]; + uint32_t* sTmemBase = (uint32_t*)sbuf; + // Per-row max/sum: 128 floats each (one per possible row) + float* sRowMax = (float*)(sbuf + 4); + float* sRowSum = sRowMax + MAX_ROWS; + bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sRowSum + MAX_ROWS) + 15) & ~(uintptr_t)15); + bf16_t* sK0 = sQ0 + TILE_SZ; + bf16_t* sPk = (bf16_t*)(((uintptr_t)(sK0 + TILE_SZ) + 127) & ~(uintptr_t)127); + bf16_t* sV = (bf16_t*)(((uintptr_t)(sPk + TILE_SZ) + 127) & ~(uintptr_t)127); + // s_p_vals: [MAX_ROWS][SK_TILE] = 128 × 128 = 16384 floats = 64KB + float* s_p_vals = (float*)(sV + V_SUB_SZ); + + // ================================================================ + // TMEM allocation (warp 4) + // ================================================================ + if (is_mma_warp) { + uint32_t smem_ptr = __cvta_generic_to_shared(sTmemBase); + tmem_alloc(smem_ptr, TMEM_N); + } + __syncthreads(); + uint32_t tb = *sTmemBase; + + // ================================================================ + // QK GEMM loop + // ================================================================ + for (int kt = 0; kt < NKT_QK; kt++) { + if (is_load_warp) { + constexpr int CORES_MN = 128 / 8; + // Zero Q SMEM + for (int i = lane; i < TILE_SZ; i += 32) sQ0[i] = 0; + // Load Q (T, hd): write T rows to canonical layout + for (int r = 0; r < T; r++) { + for (int d = lane; d < MMA_K_BF16; d += 32) { + int full_d = kt * MMA_K_BF16 + d; + int core_k = full_d / 8; + int local_c = full_d % 8; + int core_mn = r / 8; + int local_r = r % 8; + sQ0[core_k * CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c] = + q_head[r * HD + full_d]; + } + } + // Load K (s_k, hd) + for (int i = lane; i < TILE_SZ; i += 32) sK0[i] = 0; + for (int r = 0; r < s_k; r++) { + for (int d = lane; d < MMA_K_BF16; d += 32) { + int full_d = kt * MMA_K_BF16 + d; + int ck = full_d / 8, lc = full_d % 8; + int tmn = r / 8, lr = r % 8; + sK0[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = k_head[r * HD + full_d]; + } + } + } + __syncthreads(); + + if (is_mma_warp) { + uint32_t idesc = make_idesc(128, 128); + uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), 128); + uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK0), 128); + if (tid == 128) umma_ss_f16(tb, dq, dk, idesc, kt > 0); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + } + __syncthreads(); + } + + // ================================================================ + // Multi-row softmax (Warps 0-3) + // ================================================================ + // Each warp reads its 32 rows from TMEM using row-offset addressing. + // 32x32b.x8: addr = tb + (row_base << 16) + col_group*8 + // Lane l reads row (row_base + l) from TMEM. + // Only active lanes (l < warp_num_rows) do actual computation, + // but ALL 32 lanes must participate in the collective TMEM load. + // ================================================================ + if (is_softmax_warp) { + uint32_t row_base_addr = tb + (warp_row_start << 16); + + // Per-lane: one row of S values + float s_vals[SK_TILE]; + float row_max = -INFINITY; + + // Read S from TMEM: 16 iterations × 8 columns = 128 columns + for (int n = 0; n < SK_TILE / 8; n++) { + float tmp[8]; + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" + : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), + "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) + : "r"(row_base_addr + n * 8)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + + // All lanes get their row's values; only active lanes compute + if (lane < (unsigned)warp_num_rows) { + for (int c = 0; c < 8; c++) { + s_vals[n * 8 + c] = tmp[c] * scale; + row_max = fmaxf(row_max, s_vals[n * 8 + c]); + } + } + } + + // Warp-level max reduction (only active lanes participate meaningfully) + // Inactive lanes have row_max = -INFINITY, which is safe for fmax + row_max = wmax(row_max); + + if (lane < (unsigned)warp_num_rows) { + sRowMax[warp_row_start + lane] = row_max; + } + + // Compute exp and sum per row + float row_sum = 0.0f; + if (lane < (unsigned)warp_num_rows) { + for (int j = 0; j < SK_TILE; j++) { + s_vals[j] = expf(s_vals[j] - row_max); + row_sum += s_vals[j]; + } + } + row_sum = wsum(row_sum); + + if (lane < (unsigned)warp_num_rows) { + sRowSum[warp_row_start + lane] = row_sum; + } + + // Normalize and write P to s_p_vals + if (lane < (unsigned)warp_num_rows) { + float inv_sum = 1.0f / row_sum; + int my_row = warp_row_start + lane; + for (int j = 0; j < SK_TILE; j++) { + s_p_vals[my_row * SK_TILE + j] = s_vals[j] * inv_sum; + } + } + } + __syncthreads(); + + // ================================================================ + // PV GEMM loop: N=16 sub-tiles × K-tiles + // P (128, s_k) × V (s_k, hd) → O (128, hd) in TMEM + // ================================================================ + for (int n = 0; n < N_NSUB; n++) { + int d_base = n * 16; + + for (int kt = 0; kt < NKT_PV; kt++) { + if (is_load_warp) { + constexpr int CORES_MN = 128 / 8; + // Fill sPk from s_p_vals: all T rows × 16 cols + for (int i = lane; i < TILE_SZ; i += 32) sPk[i] = 0; + for (int r = 0; r < T; r++) { + for (int c = lane; c < MMA_K_BF16; c += 32) { + int global_col = kt * MMA_K_BF16 + c; + float pval = s_p_vals[r * SK_TILE + global_col]; + int core_mn = r / 8; + int local_r = r % 8; + int core_k = c / 8; + int local_c = c % 8; + sPk[core_k * CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c] = + f32_to_bf16(pval); + } + } + // Load V sub-tile + for (int i = lane; i < V_SUB_SZ; i += 32) sV[i] = 0; + for (int dd = lane; dd < 16; dd += 32) { + for (int lr = 0; lr < MMA_K_BF16; lr++) { + int r = kt * MMA_K_BF16 + lr; + int g_mn = dd / 8, g_k = lr / 8; + int llr = dd % 8, lc = lr % 8; + sV[g_k * 2 * 64 + g_mn * 64 + llr * 8 + lc] = + v_head[(d_base + dd) * s_k + r]; + } + } + } + __syncthreads(); + + if (is_mma_warp) { + uint32_t idesc_pv16 = make_idesc(128, 16); + uint64_t dp = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sPk), 128); + uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16); + if (tid == 128) umma_ss_f16(tb + n * 16, dp, dv, idesc_pv16, kt > 0); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + } + __syncthreads(); + } + } + + // ================================================================ + // Multi-row epilogue: read O from TMEM, normalize, write to GMEM + // Same row-offset addressing as softmax. + // ================================================================ + if (is_softmax_warp) { + uint32_t row_base_addr = tb + (warp_row_start << 16); + + float o_vals[HD]; + for (int n = 0; n < HD / 8; n++) { + float tmp[8]; + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" + : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), + "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) + : "r"(row_base_addr + n * 8)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + if (lane < (unsigned)warp_num_rows) { + for (int c = 0; c < 8; c++) o_vals[n * 8 + c] = tmp[c]; + } + } + + if (lane < (unsigned)warp_num_rows) { + int my_row = warp_row_start + lane; + float row_max = sRowMax[my_row]; + float row_sum = sRowSum[my_row]; + float inv_row_sum = 1.0f / row_sum; + for (int d = 0; d < HD; d++) { + o_head[my_row * HD + d] = f32_to_bf16(o_vals[d] * inv_row_sum); + } + if (lse_head) { + lse_head[my_row] = logf(row_sum) + row_max; + } + } + } + __syncthreads(); + + // TMEM dealloc + if (is_mma_warp) { + tmem_dealloc(tb, TMEM_N); + } +} + +} // namespace dsv4::kernels::attention diff --git a/tests/unit/test_fmha_6warp_multirow.cu b/tests/unit/test_fmha_6warp_multirow.cu new file mode 100644 index 00000000..8c0ddcd0 --- /dev/null +++ b/tests/unit/test_fmha_6warp_multirow.cu @@ -0,0 +1,223 @@ +/** + * Test multi-row FMHA kernel (6-warp, T>1 prefill). + * Compile with -DHD_VAL=64 etc. + * + * Tests: + * 1. T=1 decode (regression — must match single-row results) + * 2. T=2,4,8,16,32 (small prefill — only warp 0) + * 3. T=64,128 (multi-warp prefill — all 4 softmax warps) + * 4. LSE correctness for multi-row + * 5. Multi-head + batched with T>1 + */ + +#include +#include +#include +#include +#include + +#ifndef HD_VAL +#define HD_VAL 64 +#endif + +#include "dsv4/kernels/attention/fmha_common.cuh" +#include "dsv4/kernels/attention/fmha_umma_desc.cuh" + +using namespace dsv4::kernels::attention; + +static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); } +static float bf16_to_f32_host(bf16_t h) { uint32_t u=(uint32_t)h<<16; float f; memcpy(&f,&u,4); return f; } + +constexpr int HD = HD_VAL; +constexpr int SK = 128; +constexpr int TILE_SZ = 128 * MMA_K_BF16; +constexpr int V_SUB_SZ = 256; +constexpr int MAX_T = 128; + +#include "dsv4/kernels/attention/fmha_6warp_multirow.cuh" + +// Reference: compute attention for T rows +static void reference_attention_multirow( + const bf16_t* q, const bf16_t* k, const bf16_t* v, + float* o_ref, float* lse_ref, + int hd, int T, int s_k, float scale +) { + for (int t = 0; t < T; t++) { + float s[512]; // max s_k + for (int j = 0; j < s_k; j++) { + float dot = 0.0f; + for (int d = 0; d < hd; d++) { + dot += bf16_to_f32_host(q[t * hd + d]) * bf16_to_f32_host(k[j * hd + d]); + } + s[j] = dot * scale; + } + float mx = -INFINITY; + for (int j = 0; j < s_k; j++) mx = fmaxf(mx, s[j]); + float sm = 0.0f; + for (int j = 0; j < s_k; j++) { s[j] = expf(s[j] - mx); sm += s[j]; } + for (int j = 0; j < s_k; j++) s[j] /= sm; + for (int d = 0; d < hd; d++) { + float ov = 0.0f; + for (int j = 0; j < s_k; j++) ov += s[j] * bf16_to_f32_host(v[d * s_k + j]); + o_ref[t * hd + d] = ov; + } + if (lse_ref) lse_ref[t] = logf(sm) + mx; + } +} + +static int test_single_T(int T, int n_h = 1, int batch = 1) { + const char* mode = (n_h > 1) ? "multihead" : "single"; + printf("\n=== Test T=%d, n_h=%d, batch=%d, HD=%d, SK=%d (%s) ===\n", T, n_h, batch, HD, SK, mode); + const float SCALE = 1.0f / sqrtf((float)HD); + int pass = 1; + + int total_heads = batch * n_h; + + bf16_t* h_q = (bf16_t*)malloc(total_heads * T * HD * sizeof(bf16_t)); + bf16_t* h_k = (bf16_t*)malloc(total_heads * SK * HD * sizeof(bf16_t)); + bf16_t* h_v = (bf16_t*)malloc(total_heads * HD * SK * sizeof(bf16_t)); + bf16_t* h_o = (bf16_t*)calloc(total_heads * T * HD, sizeof(bf16_t)); + float* h_lse = (float*)calloc(total_heads * T, sizeof(float)); + + srand(42 + T); + for (int i = 0; i < total_heads * T * HD; i++) h_q[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f); + for (int i = 0; i < total_heads * SK * HD; i++) h_k[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f); + for (int i = 0; i < total_heads * HD * SK; i++) h_v[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f); + + bf16_t *d_q, *d_k, *d_v, *d_o; + float *d_lse; + cudaMalloc(&d_q, total_heads * T * HD * sizeof(bf16_t)); + cudaMalloc(&d_k, total_heads * SK * HD * sizeof(bf16_t)); + cudaMalloc(&d_v, total_heads * HD * SK * sizeof(bf16_t)); + cudaMalloc(&d_o, total_heads * T * HD * sizeof(bf16_t)); + cudaMalloc(&d_lse, total_heads * T * sizeof(float)); + cudaMemcpy(d_q, h_q, total_heads * T * HD * sizeof(bf16_t), cudaMemcpyHostToDevice); + cudaMemcpy(d_k, h_k, total_heads * SK * HD * sizeof(bf16_t), cudaMemcpyHostToDevice); + cudaMemcpy(d_v, h_v, total_heads * HD * SK * sizeof(bf16_t), cudaMemcpyHostToDevice); + cudaMemset(d_o, 0, total_heads * T * HD * sizeof(bf16_t)); + cudaMemset(d_lse, 0, total_heads * T * sizeof(float)); + + FmhaMultiRowParams params; + params.q = d_q; + params.k = d_k; + params.v = d_v; + params.o = d_o; + params.lse = d_lse; + params.s_k = SK; + params.T = T; + params.scale = SCALE; + params.head_dim = HD; + params.q_head_stride = T * HD; + params.q_batch_stride = n_h * T * HD; + params.k_head_stride = SK * HD; + params.k_batch_stride = n_h * SK * HD; + params.v_head_stride = HD * SK; + params.v_batch_stride = n_h * HD * SK; + params.o_head_stride = T * HD; + params.o_batch_stride = n_h * T * HD; + params.lse_head_stride = T; + params.lse_batch_stride = n_h * T; + + // SMEM: tmemBase(4) + sRowMax(128*4) + sRowSum(128*4) + padding + sQ0 + sK0 + sPk + sV + s_p_vals + // s_p_vals = 128 * 128 * 4 = 65536 bytes + int smem = 4 + 128*4 + 128*4 + 16 + TILE_SZ*2 + TILE_SZ*2 + TILE_SZ*2 + V_SUB_SZ*2 + MAX_T * SK * 4 + 256; + smem = (smem + 127) & ~127; + + if (smem > 48 * 1024) { + cudaFuncSetAttribute(fmha_6warp_multirow_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); + } + + dim3 grid(1, n_h, batch); + fmha_6warp_multirow_kernel<<>>(params); + + cudaError_t launch_err = cudaGetLastError(); + cudaError_t sync_err = cudaSuccess; + if (launch_err != cudaSuccess) { + printf("LAUNCH ERROR: %s\n", cudaGetErrorString(launch_err)); + pass = 0; goto cleanup; + } + sync_err = cudaDeviceSynchronize(); + if (sync_err != cudaSuccess) { + printf("CUDA ERROR: %s\n", cudaGetErrorString(sync_err)); + pass = 0; goto cleanup; + } + + cudaMemcpy(h_o, d_o, total_heads * T * HD * sizeof(bf16_t), cudaMemcpyDeviceToHost); + cudaMemcpy(h_lse, d_lse, total_heads * T * sizeof(float), cudaMemcpyDeviceToHost); + + // Verify each head + int checked = 0, failed = 0; + float min_cos = 1.0f; + for (int b = 0; b < batch; b++) { + for (int h = 0; h < n_h; h++) { + int idx = b * n_h + h; + float o_ref[MAX_T * 512]; // max T * max HD + float lse_ref[MAX_T]; + reference_attention_multirow( + h_q + idx * T * HD, + h_k + idx * SK * HD, + h_v + idx * HD * SK, + o_ref, lse_ref, HD, T, SK, SCALE + ); + + for (int t = 0; t < T; t++) { + float cs = 0, na = 0, nb = 0; + for (int d = 0; d < HD; d++) { + float a = bf16_to_f32_host(h_o[(idx * T + t) * HD + d]); + float b2 = o_ref[t * HD + d]; + if (fabsf(b2) > 1e-4f) { cs += a*b2; na += a*a; nb += b2*b2; } + } + cs /= (sqrtf(na) * sqrtf(nb) + 1e-10f); + if (cs < min_cos) min_cos = cs; + checked++; + + float lse_err = fabsf(h_lse[idx * T + t] - lse_ref[t]) / (fabsf(lse_ref[t]) + 1e-10f); + + if (cs < 0.999f) { + printf(" FAIL batch=%d head=%d row=%d: cos=%.6f lse_err=%.6f\n", b, h, t, cs, lse_err); + failed++; + } + } + } + } + printf(" Checked %d rows, %d failed, min_cos=%.8f\n", checked, failed, min_cos); + pass = (failed == 0); + printf(" %s\n", pass ? "PASSED" : "FAILED"); + +cleanup: + cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o); cudaFree(d_lse); + free(h_q); free(h_k); free(h_v); free(h_o); free(h_lse); + return pass; +} + +int main() { + printf("========================================\n"); + printf("Multi-row FMHA test suite (HD=%d)\n", HD); + printf("========================================\n"); + + int all_pass = 1; + + // Test 1: T=1 decode (regression — must match single-row) + all_pass &= test_single_T(1); + + // Test 2: Small prefill (only warp 0) + all_pass &= test_single_T(2); + all_pass &= test_single_T(4); + all_pass &= test_single_T(8); + all_pass &= test_single_T(16); + all_pass &= test_single_T(32); + + // Test 3: Multi-warp prefill + all_pass &= test_single_T(64); + all_pass &= test_single_T(128); + + // Test 4: Multi-head + prefill + all_pass &= test_single_T(4, 4, 1); // 4 heads, T=4 + all_pass &= test_single_T(16, 2, 1); // 2 heads, T=16 + all_pass &= test_single_T(8, 4, 2); // 4 heads × 2 batch, T=8 + + printf("\n========================================\n"); + printf("Overall: %s\n", all_pass ? "ALL PASSED" : "SOME FAILED"); + printf("========================================\n"); + return all_pass ? 0 : 1; +} diff --git a/tests/unit/test_fmha_6warp_multirow_hd128.cu b/tests/unit/test_fmha_6warp_multirow_hd128.cu new file mode 100644 index 00000000..5bdf68f4 --- /dev/null +++ b/tests/unit/test_fmha_6warp_multirow_hd128.cu @@ -0,0 +1,2 @@ +#define HD_VAL 128 +#include "test_fmha_6warp_multirow.cu" diff --git a/tests/unit/test_fmha_6warp_multirow_hd16.cu b/tests/unit/test_fmha_6warp_multirow_hd16.cu new file mode 100644 index 00000000..0077625e --- /dev/null +++ b/tests/unit/test_fmha_6warp_multirow_hd16.cu @@ -0,0 +1,2 @@ +#define HD_VAL 16 +#include "test_fmha_6warp_multirow.cu" diff --git a/tests/unit/test_fmha_6warp_multirow_hd256.cu b/tests/unit/test_fmha_6warp_multirow_hd256.cu new file mode 100644 index 00000000..3568341a --- /dev/null +++ b/tests/unit/test_fmha_6warp_multirow_hd256.cu @@ -0,0 +1,2 @@ +#define HD_VAL 256 +#include "test_fmha_6warp_multirow.cu" diff --git a/tests/unit/test_fmha_6warp_multirow_hd64.cu b/tests/unit/test_fmha_6warp_multirow_hd64.cu new file mode 100644 index 00000000..8466d256 --- /dev/null +++ b/tests/unit/test_fmha_6warp_multirow_hd64.cu @@ -0,0 +1,2 @@ +#define HD_VAL 64 +#include "test_fmha_6warp_multirow.cu"