Full FMHA HD=16 with PV GEMM via tcgen05.mma TS
This commit is contained in:
244
tests/unit/test_fmha_ts_full.cu
Normal file
244
tests/unit/test_fmha_ts_full.cu
Normal file
@@ -0,0 +1,244 @@
|
||||
/**
|
||||
* Full FMHA HD=16, SK=128 — with PV GEMM via tcgen05.mma TS
|
||||
*
|
||||
* Pipeline: Q×K^T (SS) → softmax (TMEM) → P×V (TS) → epilogue
|
||||
*
|
||||
* Key design:
|
||||
* - TMEM: 256 columns. Cols 0-127 = S/P. Cols 128-143 = O.
|
||||
* - QK: SS MMA (1 K-tile at HD=16)
|
||||
* - Softmax: read S from TMEM, compute P, write P back to TMEM
|
||||
* - PV: TS MMA, 8 K-tiles (SK=128 → SK/16 = 8 K-tiles)
|
||||
* - Epilogue: read O from TMEM, apply MMA scale (2×), write to GMEM
|
||||
*/
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cstdio>
|
||||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
|
||||
#include "dsv4/kernels/attention/fmha_common.cuh"
|
||||
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
|
||||
|
||||
using namespace dsv4::kernels::attention;
|
||||
|
||||
static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); }
|
||||
static float bf16_to_f32_host(bf16_t h) { uint32_t u=(uint32_t)h<<16; float f; memcpy(&f,&u,4); return f; }
|
||||
|
||||
constexpr int HD = 16, SK = 128, BLOCK_MN = 128;
|
||||
constexpr int NKT_QK = HD / MMA_K_BF16; // 1 K-tile for QK at HD=16
|
||||
constexpr int NKT_PV = SK / MMA_K_BF16; // 8 K-tiles for PV at SK=128
|
||||
constexpr int TMEM_P = 128;
|
||||
constexpr int TMEM_O = 32; // 16 cols → round to 32 (power of 2, min 32)
|
||||
constexpr int TMEM_N = 256;
|
||||
constexpr int TILE_SZ = BLOCK_MN * MMA_K_BF16; // 2048 BF16 per K-tile
|
||||
|
||||
__global__ void __launch_bounds__(128)
|
||||
test_fmha_ts_full(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
|
||||
const bf16_t* __restrict__ v, bf16_t* __restrict__ o_out,
|
||||
float* __restrict__ o_scalar, float scale)
|
||||
{
|
||||
const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32;
|
||||
|
||||
// SMEM: tmem_base + sQ(128,16) + sK(128,16) + sV(8 × 16×16)
|
||||
extern __shared__ char sbuf[];
|
||||
uint32_t* sTmemBase = (uint32_t*)sbuf;
|
||||
bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
|
||||
bf16_t* sK0 = sQ0 + TILE_SZ; // 2048 BF16 per K-tile
|
||||
bf16_t* sV = (bf16_t*)(((uintptr_t)(sK0 + TILE_SZ) + 127) & ~(uintptr_t)127);
|
||||
|
||||
// Load Q: (1, HD) → (128, 16) canonical
|
||||
write_q_to_smem<HD>(sQ0, q);
|
||||
// Load K: (SK, HD) → (128, 16) canonical
|
||||
write_k_to_smem<SK, HD>(sK0, k);
|
||||
|
||||
// Load V: (HD, SK) → 8 K-tiles of (16, 16) canonical
|
||||
// For PV: B = (SK, HD) = (128, 16). The kt-th K-tile: B[16*kt..16*kt+16, 0..16)
|
||||
// B[r, d] = V[d, r] = V[d * SK + r]
|
||||
for (int kt = 0; kt < NKT_PV; kt++) {
|
||||
bf16_t* sv = sV + kt * 256; // 256 BF16 = 512 bytes per (16,16) tile
|
||||
for (int i = tid; i < 256; i += 128) sv[i] = 0;
|
||||
for (int d = tid; d < HD; d += 128) {
|
||||
for (int lr = 0; lr < MMA_K_BF16; lr++) {
|
||||
int r = kt * MMA_K_BF16 + lr;
|
||||
int ck = d / 8, lc = d % 8;
|
||||
int tmn = lr / 8, llr = lr % 8;
|
||||
int dst_idx = ck * 2 * 64 + tmn * 64 + llr * 8 + lc;
|
||||
sv[dst_idx] = v[d * SK + r];
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// TMEM alloc
|
||||
if (wid == 1) tmem_alloc(__cvta_generic_to_shared(sTmemBase), TMEM_N);
|
||||
__syncthreads();
|
||||
uint32_t tb = *sTmemBase;
|
||||
uint32_t tb_o = tb + TMEM_P; // O at column 128
|
||||
|
||||
// ===== STEP 1: QK GEMM (SS) =====
|
||||
{
|
||||
uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), BLOCK_MN);
|
||||
uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK0), BLOCK_MN);
|
||||
uint32_t idesc_qk = make_idesc(BLOCK_MN, BLOCK_MN);
|
||||
for (int kt = 0; kt < NKT_QK; kt++) {
|
||||
if (tid == 0) umma_ss_f16(tb, dq, dk, idesc_qk, kt > 0);
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
__syncthreads();
|
||||
|
||||
// ===== STEP 2: Softmax =====
|
||||
if (wid == 0) {
|
||||
float s_vals[SK], row_max = -INFINITY;
|
||||
for (int n = 0; n < SK / 8; n++) {
|
||||
float tmp[8];
|
||||
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
|
||||
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
|
||||
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
|
||||
: "r"(tb + n*8));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
if (lane == 0) for (int c=0;c<8;c++) {
|
||||
s_vals[n*8+c] = tmp[c] * scale;
|
||||
row_max = fmaxf(row_max, tmp[c] * scale);
|
||||
}
|
||||
}
|
||||
row_max = wmax(row_max);
|
||||
float row_sum = 0.0f;
|
||||
if (lane == 0) for (int j=0;j<SK;j++) {
|
||||
s_vals[j] = expf(s_vals[j] - row_max);
|
||||
row_sum += s_vals[j];
|
||||
}
|
||||
row_sum = wsum(row_sum);
|
||||
if (lane == 0) for (int j=0;j<SK;j++) s_vals[j] /= row_sum;
|
||||
|
||||
// Write P to TMEM
|
||||
for (int n = 0; n < SK / 8; n++) {
|
||||
float p0=(lane==0)?s_vals[n*8+0]:0, p1=(lane==0)?s_vals[n*8+1]:0;
|
||||
float p2=(lane==0)?s_vals[n*8+2]:0, p3=(lane==0)?s_vals[n*8+3]:0;
|
||||
float p4=(lane==0)?s_vals[n*8+4]:0, p5=(lane==0)?s_vals[n*8+5]:0;
|
||||
float p6=(lane==0)?s_vals[n*8+6]:0, p7=(lane==0)?s_vals[n*8+7]:0;
|
||||
asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0],{%1,%2,%3,%4,%5,%6,%7,%8};"
|
||||
:: "r"(tb+n*8),
|
||||
"f"(p0),"f"(p1),"f"(p2),"f"(p3),"f"(p4),"f"(p5),"f"(p6),"f"(p7));
|
||||
}
|
||||
tmem_fence_store();
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// ===== STEP 3: PV GEMM (TS) =====
|
||||
// P(128,128) × V(128,16) → O(128,16)
|
||||
// 8 K-tiles: A = P cols [16*kt..16*kt+15), B = V K-tile kt
|
||||
{
|
||||
uint32_t idesc_pv = make_idesc(BLOCK_MN, HD);
|
||||
|
||||
for (int kt = 0; kt < NKT_PV; kt++) {
|
||||
uint32_t tmem_a = tb + kt * MMA_K_BF16; // A from P's kt-th 16 columns
|
||||
bf16_t* sv = sV + kt * 256;
|
||||
uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sv), 16);
|
||||
|
||||
bool accumulate = (kt > 0);
|
||||
if (tid == 0) umma_ts_f16(tb_o, tmem_a, dv, idesc_pv, accumulate);
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
__syncthreads();
|
||||
|
||||
// ===== STEP 4: Epilogue — read O from TMEM =====
|
||||
// MMA output is scaled by 0.5, so multiply by 2.0
|
||||
if (wid == 0) {
|
||||
float o_vals[HD];
|
||||
for (int n = 0; n < HD / 8; n++) {
|
||||
float tmp[8];
|
||||
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
|
||||
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
|
||||
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
|
||||
: "r"(tb_o + n*8));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c] * 2.0f;
|
||||
}
|
||||
if (lane == 0) for (int d=0;d<HD;d++) o_out[d] = f32_to_bf16(o_vals[d]);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Scalar reference (on GPU, tid==0)
|
||||
if (tid == 0) {
|
||||
float s[SK];
|
||||
for (int j=0;j<SK;j++) {
|
||||
float dot = 0.0f;
|
||||
for (int d=0;d<HD;d++) dot += bf16_to_f32(q[d]) * bf16_to_f32(k[j*HD+d]);
|
||||
s[j] = dot * scale;
|
||||
}
|
||||
float mx = -INFINITY;
|
||||
for (int j=0;j<SK;j++) mx = fmaxf(mx, s[j]);
|
||||
float sm = 0.0f;
|
||||
for (int j=0;j<SK;j++) { s[j] = expf(s[j]-mx); sm += s[j]; }
|
||||
for (int j=0;j<SK;j++) s[j] /= sm;
|
||||
for (int d=0;d<HD;d++) {
|
||||
float ov = 0.0f;
|
||||
for (int j=0;j<SK;j++) ov += s[j] * bf16_to_f32(v[d*SK+j]);
|
||||
o_scalar[d] = ov;
|
||||
}
|
||||
}
|
||||
|
||||
if (wid == 0) tmem_dealloc(tb, TMEM_N);
|
||||
}
|
||||
|
||||
int main() {
|
||||
printf("=== Full FMHA HD=16 with PV GEMM (TS) ===\n");
|
||||
const float SCALE = 1.0f / sqrtf((float)HD);
|
||||
|
||||
bf16_t* h_q = (bf16_t*)malloc(HD*sizeof(bf16_t));
|
||||
bf16_t* h_k = (bf16_t*)malloc(SK*HD*sizeof(bf16_t));
|
||||
bf16_t* h_v = (bf16_t*)malloc(HD*SK*sizeof(bf16_t));
|
||||
bf16_t* h_o = (bf16_t*)calloc(HD, sizeof(bf16_t));
|
||||
float* h_o_scalar = (float*)calloc(HD, sizeof(float));
|
||||
|
||||
srand(42);
|
||||
for (int d=0;d<HD;d++) h_q[d] = f32_to_bf16_host((float)(rand()%100)/100.0f-0.5f);
|
||||
for (int i=0;i<SK*HD;i++) h_k[i] = f32_to_bf16_host((float)(rand()%100)/100.0f-0.5f);
|
||||
for (int i=0;i<HD*SK;i++) h_v[i] = f32_to_bf16_host((float)(rand()%100)/100.0f-0.5f);
|
||||
|
||||
bf16_t *d_q,*d_k,*d_v,*d_o; float *d_o_scalar;
|
||||
cudaMalloc(&d_q, HD*sizeof(bf16_t));
|
||||
cudaMalloc(&d_k, SK*HD*sizeof(bf16_t));
|
||||
cudaMalloc(&d_v, HD*SK*sizeof(bf16_t));
|
||||
cudaMalloc(&d_o, HD*sizeof(bf16_t));
|
||||
cudaMalloc(&d_o_scalar, HD*sizeof(float));
|
||||
cudaMemcpy(d_q, h_q, HD*sizeof(bf16_t), cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(d_k, h_k, SK*HD*sizeof(bf16_t), cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(d_v, h_v, HD*SK*sizeof(bf16_t), cudaMemcpyHostToDevice);
|
||||
|
||||
// SMEM: tmem_base(4) + pad(12) + sQ(2048) + sK(2048) + sV(8*256=2048) + alignment
|
||||
int smem = (4+16 + TILE_SZ*2 + NKT_PV*256 + 256 + 127) & ~127;
|
||||
test_fmha_ts_full<<<1, 128, smem>>>(d_q, d_k, d_v, d_o, d_o_scalar, SCALE);
|
||||
|
||||
cudaError_t err = cudaDeviceSynchronize();
|
||||
if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; }
|
||||
|
||||
cudaMemcpy(h_o, d_o, HD*sizeof(bf16_t), cudaMemcpyDeviceToHost);
|
||||
cudaMemcpy(h_o_scalar, d_o_scalar, HD*sizeof(float), cudaMemcpyDeviceToHost);
|
||||
|
||||
printf("O[0..15] MMA: "); for(int d=0;d<HD;d++) printf("%.6f ",bf16_to_f32_host(h_o[d])); printf("\n");
|
||||
printf("O[0..15] ref: "); for(int d=0;d<HD;d++) printf("%.6f ",h_o_scalar[d]); printf("\n");
|
||||
|
||||
float max_diff=0, max_val=0;
|
||||
for (int d=0;d<HD;d++) {
|
||||
max_diff = fmaxf(max_diff, fabsf(bf16_to_f32_host(h_o[d]) - h_o_scalar[d]));
|
||||
max_val = fmaxf(max_val, fabsf(h_o_scalar[d]));
|
||||
}
|
||||
float rel_err = max_val>0 ? max_diff/max_val : max_diff;
|
||||
float cos_sim=0,na=0,nb=0;
|
||||
for (int d=0;d<HD;d++) { float a=bf16_to_f32_host(h_o[d]),b=h_o_scalar[d]; cos_sim+=a*b; na+=a*a; nb+=b*b; }
|
||||
cos_sim /= (sqrtf(na)*sqrtf(nb)+1e-10f);
|
||||
printf("Max rel err: %.8f | cosine: %.8f\n", rel_err, cos_sim);
|
||||
printf("Test %s\n", cos_sim > 0.999f ? "PASSED" : "FAILED");
|
||||
|
||||
cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o); cudaFree(d_o_scalar);
|
||||
free(h_q); free(h_k); free(h_v); free(h_o); free(h_o_scalar);
|
||||
return cos_sim > 0.999f ? 0 : 1;
|
||||
}
|
||||
Reference in New Issue
Block a user