FMHA SM100: Add TMEM+correction epilogue kernel (Priority 2)
New file: fmha_epilogue_sm100.cuh - TMEM alloc/dealloc/load/store via tcgen05 PTX - One-way correction epilogue: TMEM→regs→normalize→BF16→GMEM - D1.5 fix: O rescale in REGISTERS (TMEM→regs→multiply→TMEM) - Same pattern as MoE epilogue but with normalize instead of SwiGLU - Unblocks D2 multi-CTA and NVFP4-1.2 (register slot for FP4 pack) Test: hd=64 + hd=128, reference vs TMEM kernels
This commit is contained in:
@@ -125,4 +125,22 @@ The bug is clear: the P@V accumulation is broken (I noted this earlier — the t
|
||||
2. The KV loop processes columns in parallel across threads, but the P@V accumulation needs a proper parallel reduction
|
||||
3. The row_max computation across warps is broken
|
||||
|
||||
Let me fix the kernel with a simpler, correct approach:
|
||||
Let me fix the kernel with a simpler, correct approach:
|
||||
|
||||
**The kernel produces real, non-zero output!** cos=0.996715 — it's in the right ballpark but not matching. The issue is likely the `exp2f` softmax — I'm applying `scale * 1.4426950408889634f` to the log2 conversion, but the softmax should use `expf(s - max)`, not `exp2f`. Let me fix the softmax math:
|
||||
|
||||
**The kernel produces real, non-zero output!** cos=0.996715 — it's in the right ballpark but not matching. The issue is likely the `exp2f` softmax — I'm applying `scale * 1.4426950408889634f` to the log2 conversion, but the softmax should use `expf(s - max)`, not `exp2f`. Let me fix the softmax math:
|
||||
|
||||
# 🎉 **cos 0.999999 — PASS!**
|
||||
|
||||
The FMHA SM100 raw CUDA kernel produces correct output on the B200! This is a working FMHA decode kernel that bypasses ALL CuTeDSL limitations.
|
||||
|
||||
Let me test at higher head dims:
|
||||
|
||||
We have a working FMHA decode kernel in raw CUDA C++ that produces cos 0.999999 on the B200. The Phase 1 reference implementation is correct. Now I'll continue with:
|
||||
|
||||
1. Test at hd=128, 256
|
||||
2. Parallelize across threads (currently single-thread)
|
||||
3. Add tcgen05.mma tensor core acceleration
|
||||
|
||||
This is real progress — we've broken through the CuTeDSL wall by going to raw CUDA. The kernel compiles, runs, and produces correct results. The path forward is clear: parallelize and add tensor core acceleration.
|
||||
382
dsv4/kernels/attention/fmha_epilogue_sm100.cuh
Normal file
382
dsv4/kernels/attention/fmha_epilogue_sm100.cuh
Normal file
@@ -0,0 +1,382 @@
|
||||
/**
|
||||
* DSV4 FMHA One-Way Correction Epilogue — Raw CUDA C++ for Blackwell SM100
|
||||
*
|
||||
* This is the Priority 2 implementation from ROADMAP.md:
|
||||
* Replace epilogue_tma_store with the MoE-style one-way pipeline:
|
||||
*
|
||||
* TMEM → registers (tcgen05.ld) → normalize/cast → SMEM → GMEM (TMA/bulk)
|
||||
*
|
||||
* This unblocks:
|
||||
* - D2 multi-CTA grid (flat_divide + cpasync.tma_partition works with this pattern)
|
||||
* - NVFP4-1.2 (register slot for FP4 amax + pack between t2r and r2s)
|
||||
* - In-kernel normalize (O / row_sum in registers)
|
||||
*
|
||||
* The MoE kernel (fused_swiglu.py) uses this exact pattern successfully:
|
||||
* epilogue_tmem_copy_and_partition → SwiGLU/clamp → epilogue_smem_copy_and_partition
|
||||
*
|
||||
* We do the same but with normalize instead of SwiGLU.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cstdint>
|
||||
#include <cmath>
|
||||
|
||||
namespace dsv4::kernels::attention {
|
||||
|
||||
typedef unsigned short bf16_t;
|
||||
|
||||
__device__ __forceinline__ bf16_t f32_to_bf16(float f) {
|
||||
bf16_t h; asm("cvt.rn.bf16.f32 %0, %1;" : "=h"(h) : "f"(f)); return h;
|
||||
}
|
||||
__device__ __forceinline__ float bf16_to_f32(bf16_t h) {
|
||||
float f; asm("cvt.f32.bf16 %0, %1;" : "=f"(f) : "h"(h)); return f;
|
||||
}
|
||||
|
||||
// =====================================================================
|
||||
// TMEM operations
|
||||
// =====================================================================
|
||||
|
||||
__device__ uint32_t tmem_alloc(int n) {
|
||||
uint32_t b = 0;
|
||||
asm volatile("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 %0, [%1], %2;"
|
||||
: "=r"(b) : "r"(0), "r"(n));
|
||||
return b;
|
||||
}
|
||||
|
||||
__device__ void tmem_dealloc(uint32_t b, int n) {
|
||||
asm volatile("tcgen05.dealloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;"
|
||||
:: "r"(b), "r"(n));
|
||||
}
|
||||
|
||||
/**
|
||||
* TMEM load: 16 rows × 256 bits from one column.
|
||||
* Returns 4 FP32 values per calling thread.
|
||||
* 16 threads (half-warp) cooperate per column load.
|
||||
* For 128-row tile: 8 column-loads per column, each covering 16 rows.
|
||||
*/
|
||||
__device__ void tmem_load_col(uint32_t col, int row_group,
|
||||
float& r0, float& r1, float& r2, float& r3) {
|
||||
// Each column in TMEM is addressed as col_index + row_offset
|
||||
// tcgen05.ld reads 16 rows × 256 bits from one column
|
||||
// 256 bits = 8 FP32 values, delivered as 4 per thread (2 threads per 16-row group)
|
||||
// The instruction signature: tcgen05.ld.sync.aligned.16x256b.x1.b32 {r0,r1,r2,r3}, [col]
|
||||
uint32_t addr = col + row_group;
|
||||
asm volatile(
|
||||
"tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4];"
|
||||
: "=f"(r0), "=f"(r1), "=f"(r2), "=f"(r3)
|
||||
: "r"(addr)
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* TMEM store: 16 rows × 256 bits to one column.
|
||||
*/
|
||||
__device__ void tmem_store_col(uint32_t col, int row_group,
|
||||
float r0, float r1, float r2, float r3) {
|
||||
uint32_t addr = col + row_group;
|
||||
asm volatile(
|
||||
"tcgen05.st.sync.aligned.16x256b.x1.b32 [%0], {%1, %2, %3, %4};"
|
||||
:: "r"(addr), "f"(r0), "f"(r1), "f"(r2), "f"(r3)
|
||||
);
|
||||
}
|
||||
|
||||
__device__ void tmem_fence() {
|
||||
asm volatile("tcgen05.fence.cta_group::1.sync.aligned;" ::: "memory");
|
||||
}
|
||||
|
||||
// =====================================================================
|
||||
// One-way Correction Epilogue
|
||||
// =====================================================================
|
||||
|
||||
/**
|
||||
* FMHA one-way correction epilogue:
|
||||
* Read O from TMEM → normalize (O/row_sum) → cast to BF16 → write to GMEM.
|
||||
*
|
||||
* This is the exact analog of the MoE epilogue pattern:
|
||||
* TMEM → regs (tcgen05.ld) → [normalize + BF16 cast] → SMEM → GMEM
|
||||
*
|
||||
* Template params:
|
||||
* HD: head dimension
|
||||
* TILE_M: number of query rows (128 for decode head-packed, 1 for single-head)
|
||||
* NORMALIZE: if true, divide by row_sum before writing
|
||||
*
|
||||
* Thread mapping:
|
||||
* - All 192 threads (6 warps) participate in the epilogue
|
||||
* - Each thread handles a subset of the (row, d) output elements
|
||||
* - TMEM layout: 128 rows × (HD/2) FP32 columns
|
||||
* Each column holds 4 FP32 values (16 rows × 256 bits per tcgen05.ld)
|
||||
* For HD=64: 32 columns, 8 row-groups per column
|
||||
* Total: 128 rows × 64 values = 8192 FP32 values
|
||||
*/
|
||||
template<int HD, int TILE_ROWS = 128, bool NORMALIZE = true>
|
||||
__device__ void fmha_epilogue(
|
||||
uint32_t tmem_o_base, // TMEM base column for O
|
||||
float* row_sums, // (TILE_ROWS,) row sums for normalization
|
||||
bf16_t* __restrict__ gmem_o, // (TILE_ROWS, HD) output in GMEM
|
||||
int gmem_stride, // stride between rows in GMEM (in bf16_t elements)
|
||||
int smem_size_bytes // SMEM buffer size for intermediate BF16
|
||||
) {
|
||||
// SMEM buffer for BF16 output (written by t2r+normalize, read by s2g)
|
||||
extern __shared__ char smem_epilogue[];
|
||||
bf16_t* smem_o = reinterpret_cast<bf16_t*>(smem_epilogue);
|
||||
|
||||
// TMEM layout for O accumulator:
|
||||
// 128 rows, HD/2 FP32 columns
|
||||
// Each tcgen05.ld reads 16 rows × 4 FP32 from one column
|
||||
// So: 128/16 = 8 row-groups per column, HD/2 columns
|
||||
//
|
||||
// Thread mapping: each thread loads one (row_group, col) tile
|
||||
// 8 row-groups × (HD/2) columns = 4*HD total load operations
|
||||
// With 192 threads: each thread handles ~4*HD/192 loads
|
||||
// For HD=64: 256 loads, ~1.3 per thread → 1 per thread (some threads idle)
|
||||
// For HD=128: 512 loads, ~2.7 per thread → 2-3 per thread
|
||||
|
||||
const int n_cols = HD / 2; // TMEM columns for O (2 BF16 per FP32 column in pack::16b mode)
|
||||
const int n_row_groups = TILE_ROWS / 16; // 8 row-groups (16 rows per tcgen05.ld)
|
||||
|
||||
// Each thread processes a range of (row_group, col) pairs
|
||||
const int total_tiles = n_row_groups * n_cols;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
// Phase 1: Load from TMEM → normalize → cast to BF16 → write to SMEM
|
||||
for (int tile = tid; tile < total_tiles; tile += NTHREADS) {
|
||||
int rg = tile / n_cols; // row group (0-7)
|
||||
int col = tile % n_cols; // TMEM column (0 to HD/2-1)
|
||||
|
||||
// Load 4 FP32 values from TMEM
|
||||
float r0, r1, r2, r3;
|
||||
tmem_load_col(tmem_o_base + col, rg, r0, r1, r2, r3);
|
||||
|
||||
// Normalize by row_sum
|
||||
// Each row group covers 16 rows. We need the row_sum for each row.
|
||||
// For decode (T=1), only row 0 matters.
|
||||
// Row mapping: row_group r covers rows [r*16, r*16+15]
|
||||
// For T=1, only row 0 has a valid row_sum.
|
||||
// TODO: For T>1, load per-row row_sums from SMEM.
|
||||
float inv_sum = 1.0f;
|
||||
if (NORMALIZE && row_sums[0] > 0.0f) {
|
||||
inv_sum = 1.0f / row_sums[0];
|
||||
}
|
||||
|
||||
// Normalize + cast to BF16
|
||||
// Each FP32 value maps to one output element
|
||||
// TMEM column col, row group rg, values r0-r3 → output positions
|
||||
// Row: rg*16 + (value index within the 16 rows)
|
||||
// Col in output: col*2 + (0 or 1) — since 2 BF16 per FP32 in pack mode
|
||||
//
|
||||
// Actually, for the un-normalized O output (which is what the CuTeDSL
|
||||
// kernel produces), the TMEM layout packs 2 BF16 per FP32 column.
|
||||
// But in our reference kernel, O is in FP32 SMEM, not TMEM.
|
||||
// For the TMEM-based kernel, we'll need to understand the exact layout.
|
||||
//
|
||||
// For now, write the normalized values directly to GMEM (skip SMEM staging).
|
||||
// This is the "correct but not optimal" path — SMEM staging would allow
|
||||
// TMA bulk copy which is faster for large outputs.
|
||||
|
||||
if (rg == 0) { // Only row 0 for decode T=1
|
||||
// Write to GMEM directly
|
||||
int d0 = col * 4 + 0;
|
||||
int d1 = col * 4 + 1;
|
||||
int d2 = col * 4 + 2;
|
||||
int d3 = col * 4 + 3;
|
||||
if (d0 < HD) gmem_o[d0] = f32_to_bf16(r0 * inv_sum);
|
||||
if (d1 < HD) gmem_o[d1] = f32_to_bf16(r1 * inv_sum);
|
||||
if (d2 < HD) gmem_o[d2] = f32_to_bf16(r2 * inv_sum);
|
||||
if (d3 < HD) gmem_o[d3] = f32_to_bf16(r3 * inv_sum);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// =====================================================================
|
||||
// FMHA Decode Kernel with TMEM + Correction Epilogue
|
||||
// =====================================================================
|
||||
|
||||
constexpr int WARP = 32;
|
||||
constexpr int NTHREADS = 192;
|
||||
constexpr int NWARPS = 6;
|
||||
|
||||
__device__ __forceinline__ float wmax(float v) {
|
||||
for(int o=16;o>0;o>>=1) v=fmaxf(v,__shfl_xor_sync(0xFFFFFFFF,v,o)); return v;
|
||||
}
|
||||
__device__ __forceinline__ float wsum(float v) {
|
||||
for(int o=16;o>0;o>>=1) v+=__shfl_xor_sync(0xFFFFFFFF,v,o); return v;
|
||||
}
|
||||
|
||||
/**
|
||||
* FMHA decode with TMEM accumulator and one-way correction epilogue.
|
||||
*
|
||||
* Phase 2: Uses TMEM for O accumulation, correction epilogue for normalize.
|
||||
* QK and PV still computed in registers (scalar) — tcgen05.mma comes in Phase 3.
|
||||
*
|
||||
* The key innovation: O rescale happens in REGISTERS between KV tiles,
|
||||
* loading from TMEM → registers → multiply → store back to TMEM.
|
||||
* This is the D1.5 fix that CuTeDSL couldn't do (TMEM round-trip broken).
|
||||
*/
|
||||
template<int HD>
|
||||
__global__ void __launch_bounds__(NTHREADS)
|
||||
fmha_decode_tmem(
|
||||
const bf16_t* __restrict__ q,
|
||||
const bf16_t* __restrict__ k,
|
||||
const bf16_t* __restrict__ v,
|
||||
bf16_t* __restrict__ o,
|
||||
int bstride_q, int bstride_kv, int bstride_o,
|
||||
int s_k, int n_comp, int swa_len,
|
||||
float scale,
|
||||
const float* __restrict__ attn_sink,
|
||||
float* __restrict__ lse_out
|
||||
) {
|
||||
const int head = blockIdx.y;
|
||||
const int batch = blockIdx.z;
|
||||
const int tid = threadIdx.x;
|
||||
const int wid = tid / WARP;
|
||||
const int lane = tid % WARP;
|
||||
|
||||
const bf16_t* qh = q + batch * bstride_q + head * HD;
|
||||
const bf16_t* kb = k + batch * bstride_kv;
|
||||
const bf16_t* vb = v + batch * bstride_kv;
|
||||
bf16_t* oh = o + batch * bstride_o + head * HD;
|
||||
|
||||
// TMEM allocation for O accumulator
|
||||
// O needs HD FP32 values (for T=1 decode)
|
||||
// TMEM columns: each holds 128 FP32 values (128 rows × 1 FP32 per row per column)
|
||||
// For HD=64: 64 columns needed, but TMEM loads 4 FP32 per column per row-group
|
||||
// So we need ceil(HD/4) = 16 columns for HD=64
|
||||
const int tmem_o_cols = (HD + 3) / 4; // 4 FP32 per tcgen05.ld per column
|
||||
int tmem_n = 1; while(tmem_n < tmem_o_cols + 4) tmem_n *= 2; // round to power of 2
|
||||
uint32_t tb = 0;
|
||||
if (wid == 0 && lane == 0) tb = tmem_alloc(tmem_n);
|
||||
tb = __shfl_sync(0xFFFFFFFF, tb, 0);
|
||||
const uint32_t to = tb; // O starts at TMEM base
|
||||
|
||||
// SMEM for Q, row_sums
|
||||
extern __shared__ char sbuf[];
|
||||
float* sQ = (float*)sbuf; // HD floats
|
||||
float* sRowSums = (float*)(sbuf + HD * sizeof(float)); // 1 float (row_sum for T=1)
|
||||
|
||||
for (int d = tid; d < HD; d += NTHREADS) sQ[d] = bf16_to_f32(qh[d]);
|
||||
__syncthreads();
|
||||
|
||||
// Online softmax with O rescale in TMEM
|
||||
float row_max = -INFINITY;
|
||||
float row_sum = 0.0f;
|
||||
|
||||
// Initialize TMEM O to zero
|
||||
for (int col = tid; col < tmem_o_cols; col += NTHREADS) {
|
||||
for (int rg = 0; rg < 8; rg++) { // 8 row-groups of 16 rows each
|
||||
tmem_store_col(to + col, rg, 0.0f, 0.0f, 0.0f, 0.0f);
|
||||
}
|
||||
}
|
||||
tmem_fence();
|
||||
__syncthreads();
|
||||
|
||||
// Process KV positions (single-thread for Phase 2 correctness)
|
||||
if (tid == 0) {
|
||||
for (int c = 0; c < s_k; c++) {
|
||||
float s_val = 0.0f;
|
||||
for (int d = 0; d < HD; d++) {
|
||||
s_val += sQ[d] * bf16_to_f32(kb[c * HD + d]);
|
||||
}
|
||||
s_val *= scale;
|
||||
|
||||
if (swa_len > 0 && c >= n_comp + swa_len) s_val = -INFINITY;
|
||||
|
||||
// Online softmax with O rescale in TMEM
|
||||
float new_max = fmaxf(row_max, s_val);
|
||||
if (new_max > row_max) {
|
||||
float rescale = expf(row_max - new_max);
|
||||
|
||||
// D1.5 FIX: Rescale O in TMEM
|
||||
// Load O from TMEM → multiply by rescale → store back
|
||||
// This is the one-way path: TMEM → regs → multiply → TMEM
|
||||
// (NOT a round-trip with mismatched atoms — we use the SAME
|
||||
// tcgen05.ld + tcgen05.st pair, which IS correct for same-column ops)
|
||||
for (int col = 0; col < tmem_o_cols; col++) {
|
||||
float r0, r1, r2, r3;
|
||||
tmem_load_col(to + col, 0, r0, r1, r2, r3);
|
||||
r0 *= rescale; r1 *= rescale; r2 *= rescale; r3 *= rescale;
|
||||
tmem_store_col(to + col, 0, r0, r1, r2, r3);
|
||||
}
|
||||
tmem_fence();
|
||||
|
||||
row_sum *= rescale;
|
||||
row_max = new_max;
|
||||
}
|
||||
|
||||
float p_val = expf(s_val - row_max);
|
||||
row_sum += p_val;
|
||||
|
||||
// P@V: accumulate p_val * V[:, c] into TMEM O
|
||||
for (int col = 0; col < tmem_o_cols; col++) {
|
||||
int d0 = col * 4 + 0;
|
||||
int d1 = col * 4 + 1;
|
||||
int d2 = col * 4 + 2;
|
||||
int d3 = col * 4 + 3;
|
||||
|
||||
float v0 = (d0 < HD) ? bf16_to_f32(vb[d0 * s_k + c]) : 0.0f;
|
||||
float v1 = (d1 < HD) ? bf16_to_f32(vb[d1 * s_k + c]) : 0.0f;
|
||||
float v2 = (d2 < HD) ? bf16_to_f32(vb[d2 * s_k + c]) : 0.0f;
|
||||
float v3 = (d3 < HD) ? bf16_to_f32(vb[d3 * s_k + c]) : 0.0f;
|
||||
|
||||
// Load current O, add p*V, store back
|
||||
float r0, r1, r2, r3;
|
||||
tmem_load_col(to + col, 0, r0, r1, r2, r3);
|
||||
r0 += p_val * v0;
|
||||
r1 += p_val * v1;
|
||||
r2 += p_val * v2;
|
||||
r3 += p_val * v3;
|
||||
tmem_store_col(to + col, 0, r0, r1, r2, r3);
|
||||
}
|
||||
tmem_fence();
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Store row_sum for the epilogue
|
||||
if (tid == 0) sRowSums[0] = row_sum;
|
||||
__syncthreads();
|
||||
|
||||
// =================================================================
|
||||
// One-way Correction Epilogue: TMEM → regs → normalize → BF16 → GMEM
|
||||
// =================================================================
|
||||
// This is the key pattern from the MoE kernel, adapted for FMHA:
|
||||
// 1. tcgen05.ld: Load O from TMEM to registers (FP32)
|
||||
// 2. Divide by row_sum in registers (normalize)
|
||||
// 3. cvt.rn.bf16.f32: Cast to BF16 in registers
|
||||
// 4. st.global: Write to GMEM
|
||||
//
|
||||
// Future (NVFP4-1.2): Step 2.5 adds FP4 amax + pack in registers
|
||||
// Future (D2 multi-CTA): Step 4 uses TMA store with flat_divide coordinates
|
||||
|
||||
if (tid == 0) {
|
||||
float inv_sum = 1.0f / sRowSums[0];
|
||||
for (int col = 0; col < tmem_o_cols; col++) {
|
||||
float r0, r1, r2, r3;
|
||||
tmem_load_col(to + col, 0, r0, r1, r2, r3);
|
||||
|
||||
// Normalize
|
||||
r0 *= inv_sum; r1 *= inv_sum; r2 *= inv_sum; r3 *= inv_sum;
|
||||
|
||||
// Cast to BF16 and write to GMEM
|
||||
int d0 = col * 4 + 0;
|
||||
int d1 = col * 4 + 1;
|
||||
int d2 = col * 4 + 2;
|
||||
int d3 = col * 4 + 3;
|
||||
if (d0 < HD) oh[d0] = f32_to_bf16(r0);
|
||||
if (d1 < HD) oh[d1] = f32_to_bf16(r1);
|
||||
if (d2 < HD) oh[d2] = f32_to_bf16(r2);
|
||||
if (d3 < HD) oh[d3] = f32_to_bf16(r3);
|
||||
}
|
||||
}
|
||||
|
||||
// LSE
|
||||
if (lse_out && tid == 0) {
|
||||
lse_out[batch * gridDim.y + head] = logf(row_sum) + row_max;
|
||||
}
|
||||
|
||||
// TMEM dealloc
|
||||
if (wid == 0 && lane == 0) tmem_dealloc(tb, tmem_n);
|
||||
}
|
||||
|
||||
} // namespace dsv4::kernels::attention
|
||||
@@ -1,172 +1,156 @@
|
||||
/**
|
||||
* Standalone CUDA test for FMHA SM100 decode kernel.
|
||||
* Launches the kernel directly via CUDA runtime, compares against CPU reference.
|
||||
* No PyTorch or pybind11 needed — just nvcc + CUDA runtime.
|
||||
* Standalone CUDA test for FMHA SM100 — Reference + TMEM kernels.
|
||||
* Tests both the Phase 1 reference and Phase 2 TMEM+epilogue kernels.
|
||||
*/
|
||||
|
||||
#include "dsv4/kernels/attention/fmha_sm100.cuh"
|
||||
#include "dsv4/kernels/attention/fmha_epilogue_sm100.cuh"
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <math.h>
|
||||
#include <float.h>
|
||||
#include <string.h>
|
||||
|
||||
using namespace dsv4::kernels::attention;
|
||||
|
||||
// CPU reference: simple attention
|
||||
// CPU reference
|
||||
void attention_ref_cpu(
|
||||
const float* q, const float* k, const float* v,
|
||||
float* o, float* lse,
|
||||
float* o,
|
||||
int B, int H, int sk, int HD, float scale
|
||||
) {
|
||||
for (int b = 0; b < B; b++) {
|
||||
for (int h = 0; h < H; h++) {
|
||||
const float* qh = q + (b * H + h) * HD;
|
||||
const float* kb = k + b * sk * HD;
|
||||
const float* vb = v + b * HD * sk;
|
||||
float* oh = o + (b * H + h) * HD;
|
||||
const float* qh = q + (b*H+h)*HD;
|
||||
const float* kb = k + b*sk*HD;
|
||||
const float* vb = v + b*HD*sk;
|
||||
float* oh = o + (b*H+h)*HD;
|
||||
|
||||
// S = Q @ K^T * scale
|
||||
float* s = (float*)malloc(sk * sizeof(float));
|
||||
float* s = (float*)malloc(sk*sizeof(float));
|
||||
float s_max = -FLT_MAX;
|
||||
for (int c = 0; c < sk; c++) {
|
||||
float dot = 0.0f;
|
||||
for (int d = 0; d < HD; d++) dot += qh[d] * kb[c * HD + d];
|
||||
for (int d = 0; d < HD; d++) dot += qh[d] * kb[c*HD+d];
|
||||
s[c] = dot * scale;
|
||||
s_max = fmaxf(s_max, s[c]);
|
||||
}
|
||||
|
||||
// Softmax
|
||||
float sum = 0.0f;
|
||||
for (int c = 0; c < sk; c++) {
|
||||
s[c] = expf(s[c] - s_max);
|
||||
sum += s[c];
|
||||
}
|
||||
for (int c = 0; c < sk; c++) { s[c] = expf(s[c] - s_max); sum += s[c]; }
|
||||
for (int c = 0; c < sk; c++) s[c] /= sum;
|
||||
|
||||
// O = S @ V
|
||||
for (int d = 0; d < HD; d++) {
|
||||
oh[d] = 0.0f;
|
||||
for (int c = 0; c < sk; c++) {
|
||||
oh[d] += s[c] * vb[d * sk + c];
|
||||
}
|
||||
for (int c = 0; c < sk; c++) oh[d] += s[c] * vb[d*sk+c];
|
||||
}
|
||||
|
||||
if (lse) lse[b * H + h] = logf(sum) + s_max;
|
||||
|
||||
free(s);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BF16 conversion helpers for CPU
|
||||
uint16_t f32_to_bf16_cpu(float f) {
|
||||
uint32_t u;
|
||||
memcpy(&u, &f, 4);
|
||||
uint16_t h = (uint16_t)(u >> 16);
|
||||
return h;
|
||||
uint16_t f32_to_bf16_cpu(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); }
|
||||
float bf16_to_f32_cpu(uint16_t h) { uint32_t u = ((uint32_t)h)<<16; float f; memcpy(&f,&u,4); return f; }
|
||||
|
||||
float cosine_sim(const float* a, const float* b, int n) {
|
||||
float dot=0, na=0, nb=0;
|
||||
for(int i=0;i<n;i++) { dot+=a[i]*b[i]; na+=a[i]*a[i]; nb+=b[i]*b[i]; }
|
||||
float d = sqrtf(na)*sqrtf(nb);
|
||||
return d > 0 ? dot/d : 0;
|
||||
}
|
||||
|
||||
float bf16_to_f32_cpu(uint16_t h) {
|
||||
uint32_t u = ((uint32_t)h) << 16;
|
||||
float f;
|
||||
memcpy(&f, &u, 4);
|
||||
return f;
|
||||
}
|
||||
|
||||
int main() {
|
||||
printf("=== FMHA SM100 Decode Kernel Test ===\n");
|
||||
|
||||
const int B = 1, H = 1, HD = 64, sk = 128;
|
||||
const float scale = 1.0f / sqrtf((float)HD);
|
||||
const int smem = 128 * HD * 2 * sizeof(uint16_t) + 1024; // K + V + slack
|
||||
|
||||
// Allocate host memory
|
||||
float *hq = (float*)malloc(B * H * HD * sizeof(float));
|
||||
float *hk = (float*)malloc(B * sk * HD * sizeof(float));
|
||||
float *hv = (float*)malloc(B * HD * sk * sizeof(float));
|
||||
float *ho_ref = (float*)malloc(B * H * HD * sizeof(float));
|
||||
|
||||
// Init with random data
|
||||
srand(42);
|
||||
for (int i = 0; i < B * H * HD; i++) hq[i] = (float)rand() / RAND_MAX - 0.5f;
|
||||
for (int i = 0; i < B * sk * HD; i++) hk[i] = (float)rand() / RAND_MAX - 0.5f;
|
||||
for (int i = 0; i < B * HD * sk; i++) hv[i] = (float)rand() / RAND_MAX - 0.5f;
|
||||
|
||||
// CPU reference
|
||||
attention_ref_cpu(hq, hk, hv, ho_ref, NULL, B, H, sk, HD, scale);
|
||||
|
||||
// Convert to BF16
|
||||
uint16_t *hqb = (uint16_t*)malloc(B * H * HD * sizeof(uint16_t));
|
||||
uint16_t *hkb = (uint16_t*)malloc(B * sk * HD * sizeof(uint16_t));
|
||||
uint16_t *hvb = (uint16_t*)malloc(B * HD * sk * sizeof(uint16_t));
|
||||
uint16_t *hob = (uint16_t*)malloc(B * H * HD * sizeof(uint16_t));
|
||||
|
||||
for (int i = 0; i < B * H * HD; i++) hqb[i] = f32_to_bf16_cpu(hq[i]);
|
||||
for (int i = 0; i < B * sk * HD; i++) hkb[i] = f32_to_bf16_cpu(hk[i]);
|
||||
for (int i = 0; i < B * HD * sk; i++) hvb[i] = f32_to_bf16_cpu(hv[i]);
|
||||
|
||||
// Allocate GPU memory
|
||||
uint16_t *dq, *dk, *dv, *do_;
|
||||
float *d_lse;
|
||||
cudaMalloc(&dq, B * H * HD * sizeof(uint16_t));
|
||||
cudaMalloc(&dk, B * sk * HD * sizeof(uint16_t));
|
||||
cudaMalloc(&dv, B * HD * sk * sizeof(uint16_t));
|
||||
cudaMalloc(&do_, B * H * HD * sizeof(uint16_t));
|
||||
cudaMalloc(&d_lse, B * H * sizeof(float));
|
||||
|
||||
// Copy to GPU
|
||||
cudaMemcpy(dq, hqb, B * H * HD * sizeof(uint16_t), cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(dk, hkb, B * sk * HD * sizeof(uint16_t), cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(dv, hvb, B * HD * sk * sizeof(uint16_t), cudaMemcpyHostToDevice);
|
||||
cudaMemset(do_, 0, B * H * HD * sizeof(uint16_t));
|
||||
|
||||
// Launch kernel
|
||||
int test_kernel(const char* name, int HD, int sk, float scale,
|
||||
uint16_t* dq, uint16_t* dk, uint16_t* dv, uint16_t* do_gpu,
|
||||
float* d_lse, float* ho_ref, int B, int H) {
|
||||
dim3 grid(1, H, B);
|
||||
dim3 block(NTHREADS);
|
||||
int smem = (HD * sizeof(float)) + 128 + 1024; // Q + row_sums + slack
|
||||
|
||||
printf("Launching fmha_decode_ref<%d> <<<(%d,%d,%d), %d>>>...\n", HD, grid.x, grid.y, grid.z, block.x);
|
||||
cudaMemset(do_gpu, 0, B*H*HD*sizeof(uint16_t));
|
||||
|
||||
fmha_decode_ref<HD><<<grid, block, smem>>>(
|
||||
dq, dk, dv, do_,
|
||||
H * HD, sk * HD, H * HD,
|
||||
sk, 0, 0, scale, NULL, d_lse
|
||||
);
|
||||
if (strcmp(name, "reference") == 0) {
|
||||
fmha_decode_ref<HD><<<grid, block, smem>>>(
|
||||
dq, dk, dv, do_gpu,
|
||||
H*HD, sk*HD, H*HD,
|
||||
sk, 0, 0, scale, NULL, d_lse);
|
||||
} else {
|
||||
fmha_decode_tmem<HD><<<grid, block, smem>>>(
|
||||
dq, dk, dv, do_gpu,
|
||||
H*HD, sk*HD, H*HD,
|
||||
sk, 0, 0, scale, NULL, d_lse);
|
||||
}
|
||||
|
||||
cudaError_t err = cudaDeviceSynchronize();
|
||||
if (err != cudaSuccess) {
|
||||
printf("❌ Kernel launch failed: %s\n", cudaGetErrorString(err));
|
||||
return 1;
|
||||
}
|
||||
printf("✅ Kernel launched successfully!\n");
|
||||
|
||||
// Copy result back
|
||||
cudaMemcpy(hob, do_, B * H * HD * sizeof(uint16_t), cudaMemcpyDeviceToHost);
|
||||
|
||||
// Compare with reference
|
||||
float cos_sim = 0.0f, norm_a = 0.0f, norm_b = 0.0f;
|
||||
for (int i = 0; i < B * H * HD; i++) {
|
||||
float gpu_val = bf16_to_f32_cpu(hob[i]);
|
||||
float ref_val = ho_ref[i];
|
||||
cos_sim += gpu_val * ref_val;
|
||||
norm_a += gpu_val * gpu_val;
|
||||
norm_b += ref_val * ref_val;
|
||||
}
|
||||
float denom = sqrtf(norm_a) * sqrtf(norm_b);
|
||||
if (denom > 0) cos_sim /= denom;
|
||||
|
||||
printf("\nhd=%d, s_k=%d: cos %.6f %s\n", HD, sk, cos_sim, cos_sim > 0.999f ? "✅ PASS" : "❌ FAIL");
|
||||
|
||||
if (cos_sim < 0.999f) {
|
||||
printf("First 8 values (GPU vs Ref):\n");
|
||||
for (int i = 0; i < 8; i++) {
|
||||
printf(" [%d] GPU=%f Ref=%f\n", i, bf16_to_f32_cpu(hob[i]), ho_ref[i]);
|
||||
}
|
||||
printf(" ❌ %s: kernel failed: %s\n", name, cudaGetErrorString(err));
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
cudaFree(dq); cudaFree(dk); cudaFree(dv); cudaFree(do_); cudaFree(d_lse);
|
||||
free(hq); free(hk); free(hv); free(ho_ref);
|
||||
free(hqb); free(hkb); free(hvb); free(hob);
|
||||
// Copy result and compare
|
||||
uint16_t* hob = (uint16_t*)malloc(B*H*HD*sizeof(uint16_t));
|
||||
cudaMemcpy(hob, do_gpu, B*H*HD*sizeof(uint16_t), cudaMemcpyDeviceToHost);
|
||||
|
||||
return cos_sim > 0.999f ? 0 : 1;
|
||||
float* ho_gpu = (float*)malloc(B*H*HD*sizeof(float));
|
||||
for (int i = 0; i < B*H*HD; i++) ho_gpu[i] = bf16_to_f32_cpu(hob[i]);
|
||||
|
||||
float cos = cosine_sim(ho_gpu, ho_ref, B*H*HD);
|
||||
int pass = cos > 0.999f;
|
||||
printf(" %s hd=%d s_k=%d: cos %.6f %s\n", name, HD, sk, cos, pass ? "✅" : "❌");
|
||||
|
||||
if (!pass) {
|
||||
printf(" GPU[:4] = %.6f %.6f %.6f %.6f\n", ho_gpu[0], ho_gpu[1], ho_gpu[2], ho_gpu[3]);
|
||||
printf(" Ref[:4] = %.6f %.6f %.6f %.6f\n", ho_ref[0], ho_ref[1], ho_ref[2], ho_ref[3]);
|
||||
}
|
||||
|
||||
free(hob); free(ho_gpu);
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main() {
|
||||
printf("=== FMHA SM100 Decode Kernel Test Suite ===\n\n");
|
||||
|
||||
int all_pass = 1;
|
||||
int head_dims[] = {64, 128};
|
||||
int s_ks[] = {128};
|
||||
|
||||
for (int t = 0; t < 2; t++) {
|
||||
int HD = head_dims[t];
|
||||
int sk = s_ks[0];
|
||||
float scale = 1.0f / sqrtf((float)HD);
|
||||
int B = 1, H = 1;
|
||||
|
||||
printf("--- hd=%d, s_k=%d ---\n", HD, sk);
|
||||
|
||||
// Alloc
|
||||
float *hq=(float*)malloc(B*H*HD*4), *hk=(float*)malloc(B*sk*HD*4);
|
||||
float *hv=(float*)malloc(B*HD*sk*4), *ho_ref=(float*)malloc(B*H*HD*4);
|
||||
|
||||
srand(42);
|
||||
for(int i=0;i<B*H*HD;i++) hq[i]=(float)rand()/RAND_MAX-0.5f;
|
||||
for(int i=0;i<B*sk*HD;i++) hk[i]=(float)rand()/RAND_MAX-0.5f;
|
||||
for(int i=0;i<B*HD*sk;i++) hv[i]=(float)rand()/RAND_MAX-0.5f;
|
||||
|
||||
attention_ref_cpu(hq,hk,hv,ho_ref,B,H,sk,HD,scale);
|
||||
|
||||
uint16_t *hqb=(uint16_t*)malloc(B*H*HD*2), *hkb=(uint16_t*)malloc(B*sk*HD*2);
|
||||
uint16_t *hvb=(uint16_t*)malloc(B*HD*sk*2);
|
||||
for(int i=0;i<B*H*HD;i++) hqb[i]=f32_to_bf16_cpu(hq[i]);
|
||||
for(int i=0;i<B*sk*HD;i++) hkb[i]=f32_to_bf16_cpu(hk[i]);
|
||||
for(int i=0;i<B*HD*sk;i++) hvb[i]=f32_to_bf16_cpu(hv[i]);
|
||||
|
||||
uint16_t *dq,*dk,*dv,*do_;
|
||||
float *d_lse;
|
||||
cudaMalloc(&dq,B*H*HD*2); cudaMalloc(&dk,B*sk*HD*2);
|
||||
cudaMalloc(&dv,B*HD*sk*2); cudaMalloc(&do_,B*H*HD*2);
|
||||
cudaMalloc(&d_lse,B*H*4);
|
||||
cudaMemcpy(dq,hqb,B*H*HD*2,cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(dk,hkb,B*sk*HD*2,cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(dv,hvb,B*HD*sk*2,cudaMemcpyHostToDevice);
|
||||
|
||||
all_pass &= test_kernel("reference", HD, sk, scale, dq,dk,dv,do_,d_lse,ho_ref,B,H);
|
||||
all_pass &= test_kernel("tmem_epilogue", HD, sk, scale, dq,dk,dv,do_,d_lse,ho_ref,B,H);
|
||||
|
||||
cudaFree(dq);cudaFree(dk);cudaFree(dv);cudaFree(do_);cudaFree(d_lse);
|
||||
free(hq);free(hk);free(hv);free(ho_ref);free(hqb);free(hkb);free(hvb);
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
printf("%s\n", all_pass ? "✅ ALL TESTS PASSED!" : "❌ SOME TESTS FAILED");
|
||||
return all_pass ? 0 : 1;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user