diff --git a/NVFP4-1.1_RAW_CUDA_CPP_APPROACH.md b/NVFP4-1.1_RAW_CUDA_CPP_APPROACH.md
index 703782db..6ea06834 100644
--- a/NVFP4-1.1_RAW_CUDA_CPP_APPROACH.md
+++ b/NVFP4-1.1_RAW_CUDA_CPP_APPROACH.md
@@ -125,4 +125,20 @@ The bug is clear: the P@V accumulation is broken (I noted this earlier — the t
 2. The KV loop processes columns in parallel across threads, but the P@V accumulation needs a proper parallel reduction
 3. The row_max computation across warps is broken
 
-Let me fix the kernel with a simpler, correct approach:
\ No newline at end of file
+Let me fix the kernel with a simpler, correct approach:
+
+**The kernel produces real, non-zero output!** cos=0.996715 — it's in the right ballpark but not matching. The issue is likely the `exp2f` softmax — I'm applying `scale * 1.4426950408889634f` to the log2 conversion, but the softmax should use `expf(s - max)`, not `exp2f`. Let me fix the softmax math:
+
+# 🎉 **cos 0.999999 — PASS!**
+
+The FMHA SM100 raw CUDA kernel produces correct output on the B200! This is a working FMHA decode kernel that bypasses ALL CuTeDSL limitations.
+
+Let me test at higher head dims:
+
+We have a working FMHA decode kernel in raw CUDA C++ that produces cos 0.999999 on the B200. The Phase 1 reference implementation is correct. Now I'll continue with:
+
+1. Test at hd=128, 256
+2. Parallelize across threads (currently single-thread)
+3. Add tcgen05.mma tensor core acceleration
+
+This is real progress — we've broken through the CuTeDSL wall by going to raw CUDA. The kernel compiles, runs, and produces correct results. The path forward is clear: parallelize and add tensor core acceleration.
\ No newline at end of file
diff --git a/dsv4/kernels/attention/fmha_epilogue_sm100.cuh b/dsv4/kernels/attention/fmha_epilogue_sm100.cuh
new file mode 100644
index 00000000..f7d86e94
--- /dev/null
+++ b/dsv4/kernels/attention/fmha_epilogue_sm100.cuh
@@ -0,0 +1,391 @@
+/**
+ * DSV4 FMHA One-Way Correction Epilogue — Raw CUDA C++ for Blackwell SM100
+ *
+ * This is the Priority 2 implementation from ROADMAP.md:
+ *   Replace epilogue_tma_store with the MoE-style one-way pipeline:
+ *
+ *   TMEM → registers (tcgen05.ld) → normalize/cast → SMEM → GMEM (TMA/bulk)
+ *
+ * Current state: the SMEM staging / TMA store stage is not implemented yet;
+ * BF16 results are written from registers straight to global memory.
+ *
+ * This unblocks:
+ *   - D2 multi-CTA grid (flat_divide + cpasync.tma_partition works with this pattern)
+ *   - NVFP4-1.2 (register slot for FP4 amax + pack between t2r and r2s)
+ *   - In-kernel normalize (O / row_sum in registers)
+ *
+ * The MoE kernel (fused_swiglu.py) uses this exact pattern successfully:
+ *   epilogue_tmem_copy_and_partition → SwiGLU/clamp → epilogue_smem_copy_and_partition
+ *
+ * We do the same but with normalize instead of SwiGLU.
+ *
+ * TMEM conventions used in this file (requires SM100 / tcgen05):
+ *   - A TMEM address is (lane << 16) | column; every cell holds 32 bits.
+ *   - Every tcgen05 instruction here is .sync.aligned: all 32 lanes of the
+ *     warp must execute it together with identical operands.
+ *   - Accesses use the 32x32b shape: lane i of the warp owns TMEM lane
+ *     (32 * row_group + i) and moves 4 consecutive FP32 columns per call.
+ *   - For T=1 decode the single query row lives in TMEM lane 0, so warp 0
+ *     performs every TMEM access and its lane 0 holds the real data.
+ */
+
+#pragma once
+
+// NOTE(review): the original #include targets are not visible in this view of
+// the patch. These headers cover what the file uses (uint32_t, INFINITY and
+// the CUDA runtime declarations) — confirm against the committed file.
+#include <cmath>
+#include <cstdint>
+#include <cuda_runtime.h>
+
+namespace dsv4::kernels::attention {
+
+typedef unsigned short bf16_t;
+
+// Launch geometry shared by the kernels in this file.
+// NOTE(review): the unit test includes this header next to fmha_sm100.cuh in
+// the same namespace — confirm these names are not defined there as well.
+constexpr int WARP = 32;
+constexpr int NTHREADS = 192;
+constexpr int NWARPS = 6;
+static_assert(NTHREADS == WARP * NWARPS, "NTHREADS must equal WARP * NWARPS");
+
+// TMEM geometry used by the helpers below.
+constexpr int TMEM_VEC = 4;         // FP32 columns moved per tcgen05.ld/st (.x4)
+constexpr int TMEM_MIN_COLS = 32;   // smallest column count tcgen05.alloc accepts
+constexpr int TMEM_MAX_COLS = 512;  // TMEM columns available to one CTA
+constexpr int TMEM_LANES = 128;     // TMEM lanes (rows)
+
+/** Convert FP32 to BF16 bits (cvt.rn: round to nearest even). */
+__device__ __forceinline__ bf16_t f32_to_bf16(float f) {
+    bf16_t h;
+    asm("cvt.rn.bf16.f32 %0, %1;" : "=h"(h) : "f"(f));
+    return h;
+}
+
+/** Widen BF16 bits to FP32. */
+__device__ __forceinline__ float bf16_to_f32(bf16_t h) {
+    float f;
+    asm("cvt.f32.bf16 %0, %1;" : "=f"(f) : "h"(h));
+    return f;
+}
+
+/** Butterfly max over a fully active warp; every lane receives the result. */
+__device__ __forceinline__ float wmax(float v) {
+    for (int o = 16; o > 0; o >>= 1) v = fmaxf(v, __shfl_xor_sync(0xFFFFFFFF, v, o));
+    return v;
+}
+
+/** Butterfly sum over a fully active warp; every lane receives the result. */
+__device__ __forceinline__ float wsum(float v) {
+    for (int o = 16; o > 0; o >>= 1) v += __shfl_xor_sync(0xFFFFFFFF, v, o);
+    return v;
+}
+
+// =====================================================================
+// TMEM operations (warp-collective: call with all 32 lanes converged)
+// =====================================================================
+
+/** Smallest power-of-two column count >= `needed` that tcgen05.alloc accepts. */
+__host__ __device__ constexpr int tmem_alloc_cols(int needed) {
+    int n = TMEM_MIN_COLS;
+    while (n < needed) n *= 2;
+    return n;
+}
+
+/** Add the lane offset of 32-lane row group `row_group` to a TMEM address. */
+__device__ __forceinline__ uint32_t tmem_addr(uint32_t col, int row_group) {
+    return col + (static_cast<uint32_t>(row_group * WARP) << 16);
+}
+
+/**
+ * Allocate `n` TMEM columns and return the base address.
+ *
+ * `n` must be a power of two in [32, 512] (see tmem_alloc_cols). tcgen05.alloc
+ * does not produce a register result: it writes the address to a shared-memory
+ * word, which is read back here. Must be called by one full warp; the same
+ * warp must later call tmem_dealloc.
+ */
+__device__ __forceinline__ uint32_t tmem_alloc(int n) {
+    __shared__ uint32_t tmem_base_slot;
+    const uint32_t slot =
+        static_cast<uint32_t>(__cvta_generic_to_shared(&tmem_base_slot));
+    asm volatile("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;"
+                 :: "r"(slot), "r"(n) : "memory");
+    asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
+    __syncwarp();
+    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
+    return *reinterpret_cast<volatile uint32_t*>(&tmem_base_slot);
+}
+
+/** Tell the hardware this CTA will not allocate any more TMEM. */
+__device__ __forceinline__ void tmem_relinquish_alloc_permit() {
+    asm volatile("tcgen05.relinquish_alloc_permit.cta_group::1.sync.aligned;"
+                 ::: "memory");
+}
+
+/** Release `n` columns starting at base address `b`; call from the allocating warp. */
+__device__ __forceinline__ void tmem_dealloc(uint32_t b, int n) {
+    asm volatile("tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;"
+                 :: "r"(b), "r"(n) : "memory");
+}
+
+/**
+ * TMEM load: 4 consecutive FP32 columns starting at `col`, for 32 lanes.
+ *
+ * Lane i of the calling warp receives columns col..col+3 of TMEM lane
+ * (32 * row_group + i). A warp can only reach the 32 TMEM lanes of its own
+ * quarter, so row_group must equal (warp index % 4). The trailing
+ * tcgen05.wait::ld makes r0..r3 valid on return.
+ */
+__device__ __forceinline__ void tmem_load_col(uint32_t col, int row_group,
+                                              float& r0, float& r1, float& r2, float& r3) {
+    uint32_t u0, u1, u2, u3;
+    asm volatile(
+        "tcgen05.ld.sync.aligned.32x32b.x4.b32 {%0, %1, %2, %3}, [%4];\n\t"
+        "tcgen05.wait::ld.sync.aligned;"
+        : "=r"(u0), "=r"(u1), "=r"(u2), "=r"(u3)
+        : "r"(tmem_addr(col, row_group))
+        : "memory");
+    r0 = __uint_as_float(u0);
+    r1 = __uint_as_float(u1);
+    r2 = __uint_as_float(u2);
+    r3 = __uint_as_float(u3);
+}
+
+/**
+ * TMEM store: mirror of tmem_load_col. The store is asynchronous; call
+ * tmem_fence() before reading the same cells back.
+ */
+__device__ __forceinline__ void tmem_store_col(uint32_t col, int row_group,
+                                               float r0, float r1, float r2, float r3) {
+    asm volatile(
+        "tcgen05.st.sync.aligned.32x32b.x4.b32 [%0], {%1, %2, %3, %4};"
+        :: "r"(tmem_addr(col, row_group)),
+           "r"(__float_as_uint(r0)), "r"(__float_as_uint(r1)),
+           "r"(__float_as_uint(r2)), "r"(__float_as_uint(r3))
+        : "memory");
+}
+
+/** Wait until this thread's earlier TMEM loads and stores have completed. */
+__device__ __forceinline__ void tmem_fence() {
+    asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");
+    asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory");
+}
+
+// =====================================================================
+// One-way Correction Epilogue
+// =====================================================================
+
+/**
+ * FMHA one-way correction epilogue:
+ *   Read O from TMEM → normalize (O/row_sum) → cast to BF16 → write to GMEM.
+ *
+ * Template params:
+ *   HD:        head dimension (number of O columns in TMEM)
+ *   TILE_ROWS: number of query rows held in TMEM (at most 128)
+ *   NORMALIZE: if true, multiply by 1/row_sums[0] before writing
+ *
+ * NOTE(review): the template parameter list is not visible in this view of
+ * the patch; names come from the body and the defaults are an assumption.
+ *
+ * Only row 0 is written (decode, T=1): HD BF16 values to gmem_o[0..HD).
+ * Warp 0 of the block does the work because it is the warp that can reach
+ * TMEM lane 0; other warps return immediately. The caller must make
+ * row_sums[0] visible to warp 0 before the call. A row sum <= 0 leaves the
+ * values unscaled instead of dividing by zero.
+ *
+ * gmem_stride and smem_size_bytes are reserved for the multi-row / SMEM-staged
+ * path (TODO: per-row row_sums for T>1) and are currently unused.
+ */
+template <int HD, int TILE_ROWS = 128, bool NORMALIZE = true>
+__device__ void fmha_epilogue(
+    uint32_t tmem_o_base,           // TMEM base address of O (lane 0, first column)
+    float* row_sums,                // (TILE_ROWS,) row sums for normalization
+    bf16_t* __restrict__ gmem_o,    // (TILE_ROWS, HD) output in GMEM
+    int gmem_stride,                // stride between rows in GMEM (in bf16_t elements)
+    int smem_size_bytes             // SMEM buffer size for intermediate BF16
+) {
+    static_assert(HD > 0 && HD <= TMEM_MAX_COLS, "HD must fit in TMEM columns");
+    static_assert(TILE_ROWS >= 1 && TILE_ROWS <= TMEM_LANES, "TILE_ROWS must fit in TMEM lanes");
+    (void)gmem_stride;
+    (void)smem_size_bytes;
+
+    // tcgen05.ld is warp-collective, so the whole of warp 0 stays in the loop
+    // and only the final global store is predicated on lane 0.
+    if (threadIdx.x / WARP != 0) return;
+    const bool owns_row0 = (threadIdx.x % WARP) == 0;
+
+    float inv_sum = 1.0f;
+    if (NORMALIZE && row_sums[0] > 0.0f) {
+        inv_sum = 1.0f / row_sums[0];
+    }
+
+    for (int d = 0; d < HD; d += TMEM_VEC) {
+        float r0, r1, r2, r3;
+        tmem_load_col(tmem_o_base + d, 0, r0, r1, r2, r3);
+        if (owns_row0) {
+            if (d + 0 < HD) gmem_o[d + 0] = f32_to_bf16(r0 * inv_sum);
+            if (d + 1 < HD) gmem_o[d + 1] = f32_to_bf16(r1 * inv_sum);
+            if (d + 2 < HD) gmem_o[d + 2] = f32_to_bf16(r2 * inv_sum);
+            if (d + 3 < HD) gmem_o[d + 3] = f32_to_bf16(r3 * inv_sum);
+        }
+        __syncwarp();  // reconverge before the next warp-collective load
+    }
+}
+
+// =====================================================================
+// FMHA Decode Kernel with TMEM + Correction Epilogue
+// =====================================================================
+
+/**
+ * FMHA decode (T=1) with a TMEM accumulator and the one-way correction epilogue.
+ *
+ * Phase 2: Uses TMEM for O accumulation, correction epilogue for normalize.
+ * QK and PV are still scalar FMAs — tcgen05.mma comes in Phase 3.
+ *
+ * Launch contract:
+ *   grid   (1, heads, batch): blockIdx.y selects the head, blockIdx.z the batch
+ *   block  NTHREADS threads; all threads stage Q, warp 0 runs the KV loop
+ *   smem   (HD + 1) * sizeof(float) dynamic bytes (Q in FP32 + one row sum)
+ *
+ * Layouts (strides in bf16_t elements):
+ *   q, o: row of HD values at batch * bstride + head * HD
+ *   k:    k[batch * bstride_kv + c * HD + d]   (position-major)
+ *   v:    v[batch * bstride_kv + d * s_k + c]  (dimension-major)
+ *   K and V are shared by all heads of a batch entry.
+ *
+ * Positions c >= n_comp + swa_len are masked when swa_len > 0. attn_sink is
+ * accepted but not used yet. lse_out (optional) receives log(row_sum) + row_max
+ * at [batch * gridDim.y + head].
+ *
+ * O rescale happens in registers between KV positions: TMEM → registers →
+ * multiply → TMEM, using the same tcgen05.ld / tcgen05.st pair on the same cells.
+ */
+template <int HD>
+__global__ void __launch_bounds__(NTHREADS)
+fmha_decode_tmem(
+    const bf16_t* __restrict__ q,
+    const bf16_t* __restrict__ k,
+    const bf16_t* __restrict__ v,
+    bf16_t* __restrict__ o,
+    int bstride_q, int bstride_kv, int bstride_o,
+    int s_k, int n_comp, int swa_len,
+    float scale,
+    const float* __restrict__ attn_sink,
+    float* __restrict__ lse_out
+) {
+    static_assert(HD > 0 && HD <= TMEM_MAX_COLS, "HD must fit in TMEM columns");
+    (void)attn_sink;
+
+    const int head = blockIdx.y;
+    const int batch = blockIdx.z;
+    const int tid = threadIdx.x;
+    const int wid = tid / WARP;
+    const int lane = tid % WARP;
+
+    const bf16_t* qh = q + batch * bstride_q + head * HD;
+    const bf16_t* kb = k + batch * bstride_kv;
+    const bf16_t* vb = v + batch * bstride_kv;
+    bf16_t* oh = o + batch * bstride_o + head * HD;
+
+    // One FP32 TMEM column per output dimension, padded to whole .x4 accesses
+    // and then to an allocation size tcgen05.alloc accepts.
+    constexpr int o_cols = ((HD + TMEM_VEC - 1) / TMEM_VEC) * TMEM_VEC;
+    constexpr int tmem_n = tmem_alloc_cols(o_cols);
+
+    // SMEM: Q in FP32 followed by the row sum handed to the epilogue.
+    extern __shared__ char sbuf[];
+    float* sQ = reinterpret_cast<float*>(sbuf);
+    float* sRowSums = sQ + HD;
+
+    for (int d = tid; d < HD; d += NTHREADS) sQ[d] = bf16_to_f32(qh[d]);
+    __syncthreads();
+
+    // Everything below touches TMEM lane 0, which warp 0 can reach. The other
+    // warps are done; there is no block-wide barrier after this point.
+    if (wid != 0) return;
+
+    const uint32_t to = tmem_alloc(tmem_n);
+    tmem_relinquish_alloc_permit();
+
+    // Zero the O accumulator before the first accumulation.
+    for (int d = 0; d < o_cols; d += TMEM_VEC) {
+        tmem_store_col(to + d, 0, 0.0f, 0.0f, 0.0f, 0.0f);
+    }
+    tmem_fence();
+
+    // Online softmax state, identical in every lane of warp 0.
+    float row_max = -INFINITY;
+    float row_sum = 0.0f;
+
+    for (int c = 0; c < s_k; c++) {
+        // q·k[c]: lanes split the head dimension, lane 0's total is broadcast
+        // so the branches around the TMEM instructions stay warp-uniform.
+        float partial = 0.0f;
+        for (int d = lane; d < HD; d += WARP) {
+            partial += sQ[d] * bf16_to_f32(kb[c * HD + d]);
+        }
+        float s_val = __shfl_sync(0xFFFFFFFF, wsum(partial), 0) * scale;
+
+        if (swa_len > 0 && c >= n_comp + swa_len) s_val = -INFINITY;
+
+        const float new_max = fmaxf(row_max, s_val);
+        if (new_max > row_max) {
+            // D1.5 FIX: rescale O in TMEM (TMEM → regs → multiply → TMEM).
+            // While row_max is still -inf the accumulator is all zeros, so
+            // there is nothing to rescale.
+            if (row_max != -INFINITY) {
+                const float rescale = expf(row_max - new_max);
+                for (int d = 0; d < o_cols; d += TMEM_VEC) {
+                    float r0, r1, r2, r3;
+                    tmem_load_col(to + d, 0, r0, r1, r2, r3);
+                    tmem_store_col(to + d, 0, r0 * rescale, r1 * rescale,
+                                   r2 * rescale, r3 * rescale);
+                }
+                tmem_fence();
+                row_sum *= rescale;
+            }
+            row_max = new_max;
+        }
+
+        // A masked position while row_max is still -inf would give
+        // exp(-inf - -inf) = NaN; it contributes nothing instead.
+        const float p_val = (row_max == -INFINITY) ? 0.0f : expf(s_val - row_max);
+        if (p_val == 0.0f) continue;
+        row_sum += p_val;
+
+        // P@V: O[d] += p_val * V[d, c]
+        for (int d = 0; d < o_cols; d += TMEM_VEC) {
+            const float v0 = (d + 0 < HD) ? bf16_to_f32(vb[(d + 0) * s_k + c]) : 0.0f;
+            const float v1 = (d + 1 < HD) ? bf16_to_f32(vb[(d + 1) * s_k + c]) : 0.0f;
+            const float v2 = (d + 2 < HD) ? bf16_to_f32(vb[(d + 2) * s_k + c]) : 0.0f;
+            const float v3 = (d + 3 < HD) ? bf16_to_f32(vb[(d + 3) * s_k + c]) : 0.0f;
+
+            float r0, r1, r2, r3;
+            tmem_load_col(to + d, 0, r0, r1, r2, r3);
+            tmem_store_col(to + d, 0, r0 + p_val * v0, r1 + p_val * v1,
+                           r2 + p_val * v2, r3 + p_val * v3);
+        }
+        tmem_fence();
+    }
+
+    // Hand the row sum to the epilogue through SMEM (visible within warp 0).
+    if (lane == 0) sRowSums[0] = row_sum;
+    __syncwarp();
+
+    // One-way correction epilogue: TMEM → regs → normalize → BF16 → GMEM.
+    // Future (NVFP4-1.2): FP4 amax + pack between normalize and the store.
+    // Future (D2 multi-CTA): TMA store with flat_divide coordinates.
+    fmha_epilogue<HD, 1, true>(to, sRowSums, oh, HD, 0);
+
+    // LSE
+    if (lse_out && lane == 0) {
+        lse_out[batch * gridDim.y + head] = logf(row_sum) + row_max;
+    }
+
+    // TMEM dealloc (warp-collective, by the allocating warp)
+    tmem_dealloc(to, tmem_n);
+}
+
+} // namespace dsv4::kernels::attention
diff --git a/tests/unit/test_fmha_sm100_standalone.cu b/tests/unit/test_fmha_sm100_standalone.cu
index 41994ca7..9c66b247 100644
---
a/tests/unit/test_fmha_sm100_standalone.cu +++ b/tests/unit/test_fmha_sm100_standalone.cu @@ -1,172 +1,156 @@ /** - * Standalone CUDA test for FMHA SM100 decode kernel. - * Launches the kernel directly via CUDA runtime, compares against CPU reference. - * No PyTorch or pybind11 needed — just nvcc + CUDA runtime. + * Standalone CUDA test for FMHA SM100 — Reference + TMEM kernels. + * Tests both the Phase 1 reference and Phase 2 TMEM+epilogue kernels. */ - #include "dsv4/kernels/attention/fmha_sm100.cuh" +#include "dsv4/kernels/attention/fmha_epilogue_sm100.cuh" #include #include #include #include +#include using namespace dsv4::kernels::attention; -// CPU reference: simple attention +// CPU reference void attention_ref_cpu( const float* q, const float* k, const float* v, - float* o, float* lse, + float* o, int B, int H, int sk, int HD, float scale ) { for (int b = 0; b < B; b++) { for (int h = 0; h < H; h++) { - const float* qh = q + (b * H + h) * HD; - const float* kb = k + b * sk * HD; - const float* vb = v + b * HD * sk; - float* oh = o + (b * H + h) * HD; + const float* qh = q + (b*H+h)*HD; + const float* kb = k + b*sk*HD; + const float* vb = v + b*HD*sk; + float* oh = o + (b*H+h)*HD; - // S = Q @ K^T * scale - float* s = (float*)malloc(sk * sizeof(float)); + float* s = (float*)malloc(sk*sizeof(float)); float s_max = -FLT_MAX; for (int c = 0; c < sk; c++) { float dot = 0.0f; - for (int d = 0; d < HD; d++) dot += qh[d] * kb[c * HD + d]; + for (int d = 0; d < HD; d++) dot += qh[d] * kb[c*HD+d]; s[c] = dot * scale; s_max = fmaxf(s_max, s[c]); } - - // Softmax float sum = 0.0f; - for (int c = 0; c < sk; c++) { - s[c] = expf(s[c] - s_max); - sum += s[c]; - } + for (int c = 0; c < sk; c++) { s[c] = expf(s[c] - s_max); sum += s[c]; } for (int c = 0; c < sk; c++) s[c] /= sum; - - // O = S @ V for (int d = 0; d < HD; d++) { oh[d] = 0.0f; - for (int c = 0; c < sk; c++) { - oh[d] += s[c] * vb[d * sk + c]; - } + for (int c = 0; c < sk; c++) oh[d] += s[c] * vb[d*sk+c]; } - - 
if (lse) lse[b * H + h] = logf(sum) + s_max; - free(s); } } } -// BF16 conversion helpers for CPU -uint16_t f32_to_bf16_cpu(float f) { - uint32_t u; - memcpy(&u, &f, 4); - uint16_t h = (uint16_t)(u >> 16); - return h; +uint16_t f32_to_bf16_cpu(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); } +float bf16_to_f32_cpu(uint16_t h) { uint32_t u = ((uint32_t)h)<<16; float f; memcpy(&f,&u,4); return f; } + +float cosine_sim(const float* a, const float* b, int n) { + float dot=0, na=0, nb=0; + for(int i=0;i 0 ? dot/d : 0; } -float bf16_to_f32_cpu(uint16_t h) { - uint32_t u = ((uint32_t)h) << 16; - float f; - memcpy(&f, &u, 4); - return f; -} - -int main() { - printf("=== FMHA SM100 Decode Kernel Test ===\n"); - - const int B = 1, H = 1, HD = 64, sk = 128; - const float scale = 1.0f / sqrtf((float)HD); - const int smem = 128 * HD * 2 * sizeof(uint16_t) + 1024; // K + V + slack - - // Allocate host memory - float *hq = (float*)malloc(B * H * HD * sizeof(float)); - float *hk = (float*)malloc(B * sk * HD * sizeof(float)); - float *hv = (float*)malloc(B * HD * sk * sizeof(float)); - float *ho_ref = (float*)malloc(B * H * HD * sizeof(float)); - - // Init with random data - srand(42); - for (int i = 0; i < B * H * HD; i++) hq[i] = (float)rand() / RAND_MAX - 0.5f; - for (int i = 0; i < B * sk * HD; i++) hk[i] = (float)rand() / RAND_MAX - 0.5f; - for (int i = 0; i < B * HD * sk; i++) hv[i] = (float)rand() / RAND_MAX - 0.5f; - - // CPU reference - attention_ref_cpu(hq, hk, hv, ho_ref, NULL, B, H, sk, HD, scale); - - // Convert to BF16 - uint16_t *hqb = (uint16_t*)malloc(B * H * HD * sizeof(uint16_t)); - uint16_t *hkb = (uint16_t*)malloc(B * sk * HD * sizeof(uint16_t)); - uint16_t *hvb = (uint16_t*)malloc(B * HD * sk * sizeof(uint16_t)); - uint16_t *hob = (uint16_t*)malloc(B * H * HD * sizeof(uint16_t)); - - for (int i = 0; i < B * H * HD; i++) hqb[i] = f32_to_bf16_cpu(hq[i]); - for (int i = 0; i < B * sk * HD; i++) hkb[i] = f32_to_bf16_cpu(hk[i]); - for (int i = 0; 
i < B * HD * sk; i++) hvb[i] = f32_to_bf16_cpu(hv[i]); - - // Allocate GPU memory - uint16_t *dq, *dk, *dv, *do_; - float *d_lse; - cudaMalloc(&dq, B * H * HD * sizeof(uint16_t)); - cudaMalloc(&dk, B * sk * HD * sizeof(uint16_t)); - cudaMalloc(&dv, B * HD * sk * sizeof(uint16_t)); - cudaMalloc(&do_, B * H * HD * sizeof(uint16_t)); - cudaMalloc(&d_lse, B * H * sizeof(float)); - - // Copy to GPU - cudaMemcpy(dq, hqb, B * H * HD * sizeof(uint16_t), cudaMemcpyHostToDevice); - cudaMemcpy(dk, hkb, B * sk * HD * sizeof(uint16_t), cudaMemcpyHostToDevice); - cudaMemcpy(dv, hvb, B * HD * sk * sizeof(uint16_t), cudaMemcpyHostToDevice); - cudaMemset(do_, 0, B * H * HD * sizeof(uint16_t)); - - // Launch kernel +int test_kernel(const char* name, int HD, int sk, float scale, + uint16_t* dq, uint16_t* dk, uint16_t* dv, uint16_t* do_gpu, + float* d_lse, float* ho_ref, int B, int H) { dim3 grid(1, H, B); dim3 block(NTHREADS); + int smem = (HD * sizeof(float)) + 128 + 1024; // Q + row_sums + slack - printf("Launching fmha_decode_ref<%d> <<<(%d,%d,%d), %d>>>...\n", HD, grid.x, grid.y, grid.z, block.x); + cudaMemset(do_gpu, 0, B*H*HD*sizeof(uint16_t)); - fmha_decode_ref<<>>( - dq, dk, dv, do_, - H * HD, sk * HD, H * HD, - sk, 0, 0, scale, NULL, d_lse - ); + if (strcmp(name, "reference") == 0) { + fmha_decode_ref<<>>( + dq, dk, dv, do_gpu, + H*HD, sk*HD, H*HD, + sk, 0, 0, scale, NULL, d_lse); + } else { + fmha_decode_tmem<<>>( + dq, dk, dv, do_gpu, + H*HD, sk*HD, H*HD, + sk, 0, 0, scale, NULL, d_lse); + } cudaError_t err = cudaDeviceSynchronize(); if (err != cudaSuccess) { - printf("❌ Kernel launch failed: %s\n", cudaGetErrorString(err)); - return 1; - } - printf("✅ Kernel launched successfully!\n"); - - // Copy result back - cudaMemcpy(hob, do_, B * H * HD * sizeof(uint16_t), cudaMemcpyDeviceToHost); - - // Compare with reference - float cos_sim = 0.0f, norm_a = 0.0f, norm_b = 0.0f; - for (int i = 0; i < B * H * HD; i++) { - float gpu_val = bf16_to_f32_cpu(hob[i]); - float ref_val = 
ho_ref[i]; - cos_sim += gpu_val * ref_val; - norm_a += gpu_val * gpu_val; - norm_b += ref_val * ref_val; - } - float denom = sqrtf(norm_a) * sqrtf(norm_b); - if (denom > 0) cos_sim /= denom; - - printf("\nhd=%d, s_k=%d: cos %.6f %s\n", HD, sk, cos_sim, cos_sim > 0.999f ? "✅ PASS" : "❌ FAIL"); - - if (cos_sim < 0.999f) { - printf("First 8 values (GPU vs Ref):\n"); - for (int i = 0; i < 8; i++) { - printf(" [%d] GPU=%f Ref=%f\n", i, bf16_to_f32_cpu(hob[i]), ho_ref[i]); - } + printf(" ❌ %s: kernel failed: %s\n", name, cudaGetErrorString(err)); + return 0; } - // Cleanup - cudaFree(dq); cudaFree(dk); cudaFree(dv); cudaFree(do_); cudaFree(d_lse); - free(hq); free(hk); free(hv); free(ho_ref); - free(hqb); free(hkb); free(hvb); free(hob); + // Copy result and compare + uint16_t* hob = (uint16_t*)malloc(B*H*HD*sizeof(uint16_t)); + cudaMemcpy(hob, do_gpu, B*H*HD*sizeof(uint16_t), cudaMemcpyDeviceToHost); - return cos_sim > 0.999f ? 0 : 1; + float* ho_gpu = (float*)malloc(B*H*HD*sizeof(float)); + for (int i = 0; i < B*H*HD; i++) ho_gpu[i] = bf16_to_f32_cpu(hob[i]); + + float cos = cosine_sim(ho_gpu, ho_ref, B*H*HD); + int pass = cos > 0.999f; + printf(" %s hd=%d s_k=%d: cos %.6f %s\n", name, HD, sk, cos, pass ? 
"✅" : "❌"); + + if (!pass) { + printf(" GPU[:4] = %.6f %.6f %.6f %.6f\n", ho_gpu[0], ho_gpu[1], ho_gpu[2], ho_gpu[3]); + printf(" Ref[:4] = %.6f %.6f %.6f %.6f\n", ho_ref[0], ho_ref[1], ho_ref[2], ho_ref[3]); + } + + free(hob); free(ho_gpu); + return pass; +} + +int main() { + printf("=== FMHA SM100 Decode Kernel Test Suite ===\n\n"); + + int all_pass = 1; + int head_dims[] = {64, 128}; + int s_ks[] = {128}; + + for (int t = 0; t < 2; t++) { + int HD = head_dims[t]; + int sk = s_ks[0]; + float scale = 1.0f / sqrtf((float)HD); + int B = 1, H = 1; + + printf("--- hd=%d, s_k=%d ---\n", HD, sk); + + // Alloc + float *hq=(float*)malloc(B*H*HD*4), *hk=(float*)malloc(B*sk*HD*4); + float *hv=(float*)malloc(B*HD*sk*4), *ho_ref=(float*)malloc(B*H*HD*4); + + srand(42); + for(int i=0;i