diff --git a/tests/unit/test_fmha_smem_p.cu b/tests/unit/test_fmha_smem_p.cu new file mode 100644 index 00000000..62fdccfe --- /dev/null +++ b/tests/unit/test_fmha_smem_p.cu @@ -0,0 +1,265 @@ +/** + * Full FMHA HD=16, SK=128 — PV via SS MMA (SMEM-P approach) + * + * Pipeline: Q×K^T (SS) → softmax (TMEM read → SMEM write) → P×V (SS) → epilogue + * + * Key insight: the tcgen05.mma TS A-operand TMEM layout (Layout A) does NOT + * match the 32x32b store format. Using SS MMA for both QK and PV avoids the + * TMEM layout issue entirely, because both operands come from SMEM where we + * control the canonical K-major layout. + * + * This is the SMEM-P approach, similar to what CuTeDSL uses for hd > 64, + * but applied at all head dims for the raw CUDA path. + * + * SMEM layout: + * sQ: (128, 16) — Q K-tile + * sK: (128, 16) — K K-tile + * sP: (128, 128) — softmax output, written in canonical K-major layout + * sV: 8 × (16, 16) — V K-tiles + */ + +#include +#include +#include +#include +#include + +#include "dsv4/kernels/attention/fmha_common.cuh" +#include "dsv4/kernels/attention/fmha_umma_desc.cuh" + +using namespace dsv4::kernels::attention; + +static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); } +static float bf16_to_f32_host(bf16_t h) { uint32_t u=(uint32_t)h<<16; float f; memcpy(&f,&u,4); return f; } + +constexpr int HD = 16, SK = 128, BLOCK_MN = 128; +constexpr int NKT_QK = HD / MMA_K_BF16; // 1 +constexpr int NKT_PV = SK / MMA_K_BF16; // 8 +constexpr int TMEM_N = 128; // Just S and O, no P in TMEM +constexpr int TILE_SZ = BLOCK_MN * MMA_K_BF16; + +__global__ void __launch_bounds__(128) +test_fmha_smem_p(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k, + const bf16_t* __restrict__ v, bf16_t* __restrict__ o_out, + float* __restrict__ o_scalar, float scale) +{ + const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32; + + extern __shared__ char sbuf[]; + uint32_t* sTmemBase = (uint32_t*)sbuf; + bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15); + bf16_t* sK0 = sQ0 + TILE_SZ; + // sP: softmax output in canonical (128, 128) layout + // (128, 128): CORES_MN=16, CORES_K=16 + // Each core: 64 BF16. Total: 16*16*64 = 16384 BF16 = 32768 bytes + bf16_t* sP = (bf16_t*)(((uintptr_t)(sK0 + TILE_SZ) + 127) & ~(uintptr_t)127); + // sV: 8 K-tiles of (16, 16) + bf16_t* sV = (bf16_t*)(((uintptr_t)(sP + 128 * SK) + 127) & ~(uintptr_t)127); + + // Load Q, K + write_q_to_smem(sQ0, q); + write_k_to_smem(sK0, k); + + // Load V K-tiles + for (int kt = 0; kt < NKT_PV; kt++) { + bf16_t* sv = sV + kt * 256; + for (int i = tid; i < 256; i += 128) sv[i] = 0; + for (int d = tid; d < HD; d += 128) { + for (int lr = 0; lr < MMA_K_BF16; lr++) { + int r = kt * MMA_K_BF16 + lr; + int ck = d / 8, lc = d % 8; + int tmn = lr / 8, llr = lr % 8; + int dst_idx = ck * 2 * 64 + tmn * 64 + llr * 8 + lc; + sv[dst_idx] = v[d * SK + r]; + } + } + } + __syncthreads(); + + // TMEM alloc: 128 columns for S + if (wid == 1) tmem_alloc(__cvta_generic_to_shared(sTmemBase), TMEM_N); + __syncthreads(); + uint32_t tb = *sTmemBase; + + // ===== STEP 1: QK GEMM (SS) ===== + { + uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), BLOCK_MN); + uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK0), BLOCK_MN); + uint32_t idesc_qk = make_idesc(BLOCK_MN, BLOCK_MN); + for (int kt = 0; kt < NKT_QK; kt++) { + if (tid == 0) umma_ss_f16(tb, dq, dk, idesc_qk, kt > 0); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + __syncthreads(); + } + } + + // ===== STEP 2: Softmax — read S from TMEM, write P to SMEM ===== + if (wid == 0) { + float s_vals[SK], row_max = -INFINITY; + for (int n = 0; n < SK / 8; n++) { + float tmp[8]; + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" + : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), + "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) + : "r"(tb + n*8)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + if (lane == 0) for (int c=0;c<8;c++) { + s_vals[n*8+c] = tmp[c] * scale; + row_max = fmaxf(row_max, tmp[c] * scale); + } + } + row_max = wmax(row_max); + float row_sum = 0.0f; + if (lane == 0) for (int j=0;j 0); + if (tid == 0) umma_ss_f16(tb, dp, dv, idesc_pv, accumulate); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + __syncthreads(); + } + } + + // ===== STEP 4: Epilogue — read O from TMEM ===== + if (wid == 0) { + float o_vals[HD]; + for (int n = 0; n < HD / 8; n++) { + float tmp[8]; + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" + : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), + "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) + : "r"(tb + n*8)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c] * 2.0f; // Undo MMA 0.5 scale + } + if (lane == 0) for (int d=0;d>>(d_q, d_k, d_v, d_o, d_o_scalar, SCALE); + + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; } + + cudaMemcpy(h_o, d_o, HD*sizeof(bf16_t), cudaMemcpyDeviceToHost); + cudaMemcpy(h_o_scalar, d_o_scalar, HD*sizeof(float), cudaMemcpyDeviceToHost); + + printf("O[0..15] MMA: "); for(int d=0;d0 ? max_diff/max_val : max_diff; + float cos_sim=0,na=0,nb=0; + for (int d=0;d 0.999f ? "PASSED" : "FAILED"); + + cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o); cudaFree(d_o_scalar); + free(h_q); free(h_k); free(h_v); free(h_o); free(h_o_scalar); + return cos_sim > 0.999f ? 0 : 1; +}