diff --git a/tests/unit/test_fmha_hd64_smem_p.cu b/tests/unit/test_fmha_hd64_smem_p.cu new file mode 100644 index 00000000..dd09dfc2 --- /dev/null +++ b/tests/unit/test_fmha_hd64_smem_p.cu @@ -0,0 +1,240 @@ +/** + * Full FMHA HD=64, SK=128 — PV via SS MMA (SMEM-P approach) + * + * Extends test_fmha_v5 (HD=16) to HD=64. + * 4 QK K-tiles, 8 PV K-tiles, per-K-tile P fill. + */ + +#include +#include +#include +#include +#include + +#include "dsv4/kernels/attention/fmha_common.cuh" +#include "dsv4/kernels/attention/fmha_umma_desc.cuh" + +using namespace dsv4::kernels::attention; + +static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); } +static float bf16_to_f32_host(bf16_t h) { uint32_t u=(uint32_t)h<<16; float f; memcpy(&f,&u,4); return f; } + +constexpr int HD = 64, SK = 128, BLOCK_MN = 128; +constexpr int NKT_QK = HD / MMA_K_BF16; // 4 +constexpr int NKT_PV = SK / MMA_K_BF16; // 8 +constexpr int TMEM_N = 128; +constexpr int TILE_SZ = BLOCK_MN * MMA_K_BF16; + +__global__ void __launch_bounds__(128) +test_fmha_hd64_smem_p(const bf16_t* q, const bf16_t* k, const bf16_t* v, + bf16_t* o_out, float* o_scalar, float scale) +{ + const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32; + + extern __shared__ char sbuf[]; + uint32_t* sTmemBase = (uint32_t*)sbuf; + bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15); + bf16_t* sK0 = sQ0 + NKT_QK * TILE_SZ; // 4 Q K-tiles + bf16_t* sPk = (bf16_t*)(((uintptr_t)(sK0 + NKT_QK * TILE_SZ) + 127) & ~(uintptr_t)127); + bf16_t* sV = (bf16_t*)(((uintptr_t)(sPk + TILE_SZ) + 127) & ~(uintptr_t)127); + float* s_p_vals = (float*)(sV + NKT_PV * 256); + + // Load Q K-tiles + for (int kt = 0; kt < NKT_QK; kt++) { + bf16_t* sq = sQ0 + kt * TILE_SZ; + for (int i = tid; i < TILE_SZ; i += 128) sq[i] = 0; + for (int d = tid; d < MMA_K_BF16; d += 128) { + int ck = d / 8, lc = d % 8; + sq[ck * 16 * 64 + lc] = q[kt * MMA_K_BF16 + d]; // Row 0 only + } + } + + // Load K K-tiles + for (int kt = 0; kt < NKT_QK; kt++) { + bf16_t* sk = sK0 + kt * TILE_SZ; + for (int i = tid; i < TILE_SZ; i += 128) sk[i] = 0; + for (int r = 0; r < SK; r++) { + for (int d = tid; d < MMA_K_BF16; d += 128) { + int ck = d / 8, lc = d % 8; + int tmn = r / 8, lr = r % 8; + sk[ck * 16 * 64 + tmn * 64 + lr * 8 + lc] = k[r * HD + kt * MMA_K_BF16 + d]; + } + } + } + + // Load V K-tiles + for (int kt = 0; kt < NKT_PV; kt++) { + bf16_t* sv = sV + kt * 256; + for (int i = tid; i < 256; i += 128) sv[i] = 0; + for (int d = tid; d < HD; d += 128) { + for (int lr = 0; lr < MMA_K_BF16; lr++) { + int r = kt * MMA_K_BF16 + lr; + int g_mn = d / 8, g_k = lr / 8; + int llr = d % 8, lc = lr % 8; + sv[g_k * 2 * 64 + g_mn * 64 + llr * 8 + lc] = v[d * SK + r]; + } + } + } + __syncthreads(); + + // TMEM alloc + if (wid == 1) tmem_alloc(__cvta_generic_to_shared(sTmemBase), TMEM_N); + __syncthreads(); + uint32_t tb = *sTmemBase; + + // ===== QK GEMM (4 K-tiles, accumulate) ===== + { + uint32_t idesc = make_idesc(BLOCK_MN, BLOCK_MN); + for (int kt = 0; kt < NKT_QK; kt++) { + bf16_t* sq = sQ0 + kt * TILE_SZ; + bf16_t* sk = sK0 + kt * TILE_SZ; + uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sq), BLOCK_MN); + uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sk), BLOCK_MN); + if (tid == 0) umma_ss_f16(tb, dq, dk, idesc, kt > 0); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + __syncthreads(); + } + } + + // ===== Softmax ===== + if (wid == 0) { + float s_vals[SK], row_max = -INFINITY; + for (int n = 0; n < SK / 8; n++) { + float tmp[8]; + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" + : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), + "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) + : "r"(tb + n*8)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + if (lane == 0) for (int c=0;c<8;c++) { + s_vals[n*8+c] = tmp[c] * scale * 2.0f; // ×2 to undo QK MMA 0.5 scale + row_max = fmaxf(row_max, tmp[c] * scale * 2.0f); + } + } + row_max = wmax(row_max); + float row_sum = 0.0f; + if (lane == 0) for (int j=0;j 0); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + __syncthreads(); + } + } + + // ===== Epilogue ===== + if (wid == 0) { + float o_vals[HD]; + for (int n = 0; n < HD / 8; n++) { + float tmp[8]; + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" + : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), + "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) + : "r"(tb + n*8)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c]; + } + if (lane == 0) for (int d=0;d>>(d_q, d_k, d_v, d_o, d_o_scalar, SCALE); + + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; } + + cudaMemcpy(h_o, d_o, HD*sizeof(bf16_t), cudaMemcpyDeviceToHost); + cudaMemcpy(h_o_scalar, d_o_scalar, HD*sizeof(float), cudaMemcpyDeviceToHost); + + printf("O[0..7] MMA: "); for(int d=0;d<8;d++) printf("%.6f ",bf16_to_f32_host(h_o[d])); printf("\n"); + printf("O[0..7] ref: "); for(int d=0;d<8;d++) printf("%.6f ",h_o_scalar[d]); printf("\n"); + + float max_diff=0, max_val=0; + for (int d=0;d0 ? max_diff/max_val : max_diff; + float cos_sim=0,na=0,nb=0; + for (int d=0;d 0.999f ? "PASSED" : "FAILED"); + + cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o); cudaFree(d_o_scalar); + free(h_q); free(h_k); free(h_v); free(h_o); free(h_o_scalar); + return cos_sim > 0.999f ? 0 : 1; +}