Files
nvfp4-megamoe-kernel/tests/unit/test_fmha_hd64.cu

186 lines
7.8 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
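/**
* Full UMMA FMHA — HD=64, SK=128, T=1 (decode)
*
* Pipeline: S = Q×K^T (UMMA, both operands from shared memory) → softmax on the
* TMEM scores → P·V (register math) → output.
* P·V is computed in registers because decode has a single query row (T=1):
*   O[d] = Σ_j P[0,j] × V[d,j], with V stored transposed as [HD][SK].
* The kernel also computes a single-thread scalar reference; main() compares
* the two outputs by cosine similarity.
*/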
#include <cuda_runtime.h>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <vector>
#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
using namespace dsv4::kernels::attention;
// Host-side bf16 helpers working on raw bit patterns (bf16_t is assumed to be
// a 16-bit integer holding the bits — TODO confirm against fmha_common.cuh).
// f32 -> bf16 keeps the upper 16 bits of the float, i.e. truncation (no rounding).
static bf16_t f32_to_bf16_host(float f) {
    uint32_t bits;
    memcpy(&bits, &f, sizeof bits);
    return static_cast<uint16_t>(bits >> 16);
}
// bf16 -> f32 places the 16 bits in the high half of a float word.
static float bf16_to_f32_host(bf16_t h) {
    const uint32_t bits = static_cast<uint32_t>(h) << 16;
    float out;
    memcpy(&out, &bits, sizeof out);
    return out;
}
// Problem shape: head dimension HD, key/value length SK, one query row (decode).
// NKT = number of MMA_K_BF16-wide K-slices that together cover HD.
constexpr int HD = 64, SK = 128, NKT = HD / MMA_K_BF16;
// UMMA tile geometry: each operand tile is BLOCK_MN rows x MMA_K_BF16 columns
// (TILE_SZ elements) laid out as 8x8 core matrices; CORES_MN = number of core
// matrices along the M/N dimension.
constexpr int BLOCK_MN = 128, TILE_SZ = BLOCK_MN * MMA_K_BF16, CORES_MN = BLOCK_MN / 8;

// Compile-time preconditions of the staging/indexing math used by the kernel.
static_assert(HD % MMA_K_BF16 == 0, "HD must be a whole number of UMMA K-slices");
static_assert(MMA_K_BF16 % 8 == 0, "each K-slice must consist of whole 8-column core matrices");
static_assert(SK % 8 == 0 && SK <= BLOCK_MN,
              "all SK key rows must fit in one UMMA tile and be read with x8 TMEM loads");
// ---------------------------------------------------------------------------
// PTX helpers used to track completion of the asynchronous tcgen05.mma ops.
// Prefixed names avoid clashing with helpers exported by fmha_common.cuh
// through the using-directive above.
// ---------------------------------------------------------------------------

// Initialise a shared-memory mbarrier that completes a phase after `count` arrivals.
static __device__ __forceinline__ void hd64_mbar_init(uint32_t mbar, uint32_t count) {
    asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;" :: "r"(mbar), "r"(count) : "memory");
}

// Arrive once on `mbar` when every tcgen05 op previously issued by this thread has completed.
static __device__ __forceinline__ void hd64_umma_commit(uint32_t mbar) {
    asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
                 :: "r"(mbar) : "memory");
}

// Spin until the mbarrier phase with the given parity has completed.
static __device__ __forceinline__ void hd64_mbar_wait(uint32_t mbar, uint32_t parity) {
    uint32_t done = 0;
    do {
        asm volatile("{\n\t"
                     ".reg .pred p;\n\t"
                     "mbarrier.try_wait.parity.shared::cta.b64 p, [%1], %2;\n\t"
                     "selp.u32 %0, 1, 0, p;\n\t"
                     "}"
                     : "=r"(done) : "r"(mbar), "r"(parity) : "memory");
    } while (!done);
}

/**
 * test_fmha_hd64 — single-CTA decode attention for one query row.
 *
 * Launch: exactly 1 block of 128 threads (4 warps). Warp 1 owns the TMEM
 * allocation, thread 0 issues the UMMAs, warp 0 runs softmax and P·V, and
 * thread 0 additionally computes a scalar reference.
 *
 * Dynamic shared memory: [0,4) TMEM base address, then NKT Q tiles and NKT K
 * tiles (16-byte aligned for the UMMA descriptors), then one 8-byte mbarrier
 * for UMMA completion. The mbarrier uses the 256-byte slack main() adds.
 *
 * q        : HD bf16 values (single query row)
 * k        : SK x HD bf16, indexed k[r*HD + d]
 * v        : HD x SK bf16, indexed v[d*SK + j] (V transposed)
 * o_out    : HD bf16 outputs of the UMMA/TMEM path
 * o_scalar : HD float outputs of the scalar reference
 * scale    : softmax scale applied to Q·K^T
 */
__global__ void __launch_bounds__(128)
test_fmha_hd64(const bf16_t* q, const bf16_t* k, const bf16_t* v,
               bf16_t* o_out, float* o_scalar, float scale)
{
    constexpr int kThreads = 128;  // must match __launch_bounds__ and the launch in main()
    const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32;

    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;
    bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
    bf16_t* sK0 = sQ0 + NKT * TILE_SZ;
    // The tile region is a multiple of 16 bytes and starts 16-byte aligned,
    // so the mbarrier placed right after it is 8-byte aligned.
    uint64_t* sMmaBar = (uint64_t*)(sK0 + NKT * TILE_SZ);
    const uint32_t mma_bar = (uint32_t)__cvta_generic_to_shared(sMmaBar);

    // 1) Stage Q and K in the K-major, no-swizzle core-matrix layout
    //    (8x8 core matrices of 64 elements). Zero first: only row 0 of Q is real.
    for (int i = tid; i < NKT * TILE_SZ; i += kThreads) { sQ0[i] = 0; sK0[i] = 0; }
    // The zero-fill and the stores below target the same addresses from
    // different threads/warps, so they must be separated by a barrier.
    __syncthreads();
    for (int kt = 0; kt < NKT; kt++) {
        bf16_t* sq = sQ0 + kt * TILE_SZ;
        bf16_t* sk = sK0 + kt * TILE_SZ;
        // Q: M-row 0 only.
        for (int d = tid; d < MMA_K_BF16; d += kThreads) {
            const int ck = d / 8, lc = d % 8;
            sq[ck * CORES_MN * 64 + lc] = q[kt * MMA_K_BF16 + d];
        }
        // K: all SK rows, one thread per (row, column) element of this K-slice.
        for (int i = tid; i < SK * MMA_K_BF16; i += kThreads) {
            const int r = i / MMA_K_BF16, d = i % MMA_K_BF16;
            const int ck = d / 8, lc = d % 8;
            const int tmn = r / 8, lr = r % 8;
            sk[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = k[r * HD + kt * MMA_K_BF16 + d];
        }
    }
    if (tid == 0) {
        hd64_mbar_init(mma_bar, 1);
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
    }
    // tcgen05.mma reads its operands through the async proxy; make the
    // generic-proxy shared-memory stores (and the mbarrier init) visible to it.
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
    __syncthreads();

    // 2) Allocate 128 TMEM columns (warp 1) and publish the base address.
    if (wid == 1) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 128);
    asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
    __syncthreads();
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    const uint32_t tb = *sTmemBase;

    // 3) S = Q·K^T: thread 0 issues one UMMA per K-slice (accumulating from
    //    the second slice on), then commits them to the mbarrier.
    if (tid == 0) {
        const uint32_t idesc = make_idesc(BLOCK_MN, BLOCK_MN);
        for (int kt = 0; kt < NKT; kt++) {
            const uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0 + kt * TILE_SZ), BLOCK_MN);
            const uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK0 + kt * TILE_SZ), BLOCK_MN);
            umma_ss_f16(tb, dq, dk, idesc, kt > 0);
        }
        hd64_umma_commit(mma_bar);
    }
    // tcgen05.mma is asynchronous: a thread barrier does not wait for it, so
    // wait for the committed phase before any TMEM load.
    hd64_mbar_wait(mma_bar, 0);
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");

    // 4) Softmax + P·V in warp 0. Lane 0 receives TMEM lane 0 (query row 0);
    //    other lanes keep row_max = -inf and row_sum = 0, so the warp
    //    reductions return row 0's values.
    if (wid == 0) {
        float s_vals[SK], row_max = -INFINITY;
        for (int n = 0; n < SK / 8; n++) {
            float tmp[8];
            asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) : "r"(tb + n*8));
            asm volatile("tcgen05.wait::ld.sync.aligned;");
            if (lane == 0) for (int c = 0; c < 8; c++) { s_vals[n*8+c] = tmp[c]*scale; row_max = fmaxf(row_max, tmp[c]*scale); }
        }
        row_max = wmax(row_max);
        float row_sum = 0.0f;
        if (lane == 0) for (int j = 0; j < SK; j++) { s_vals[j] = expf(s_vals[j]-row_max); row_sum += s_vals[j]; }
        row_sum = wsum(row_sum);
        if (lane == 0) for (int j = 0; j < SK; j++) s_vals[j] /= row_sum;

        // Write P back to TMEM (lane 0 holds real values, other lanes write 0).
        // Nothing in this kernel reads it back; P·V below uses s_vals directly.
        for (int n = 0; n < SK / 8; n++) {
            float p0=(lane==0)?s_vals[n*8+0]:0, p1=(lane==0)?s_vals[n*8+1]:0;
            float p2=(lane==0)?s_vals[n*8+2]:0, p3=(lane==0)?s_vals[n*8+3]:0;
            float p4=(lane==0)?s_vals[n*8+4]:0, p5=(lane==0)?s_vals[n*8+5]:0;
            float p6=(lane==0)?s_vals[n*8+6]:0, p7=(lane==0)?s_vals[n*8+7]:0;
            asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0],{%1,%2,%3,%4,%5,%6,%7,%8};" :: "r"(tb+n*8),"f"(p0),"f"(p1),"f"(p2),"f"(p3),"f"(p4),"f"(p5),"f"(p6),"f"(p7));
        }
        tmem_fence_store();

        // 5) P·V for the single query row: O[d] = Σ_j P[0,j] × V[d,j].
        if (lane == 0) {
            for (int d = 0; d < HD; d++) {
                float ov = 0.0f;
                for (int j = 0; j < SK; j++)
                    ov += s_vals[j] * bf16_to_f32(v[d * SK + j]);
                o_out[d] = f32_to_bf16(ov);
            }
        }
    }
    // TMEM traffic is finished; order it before the deallocation below.
    asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
    __syncthreads();

    // 6) Scalar reference (single thread, fp32 math on the same bf16 inputs).
    if (tid == 0) {
        float s[SK];
        for (int j = 0; j < SK; j++) {
            float dot = 0.0f;
            for (int d = 0; d < HD; d++) dot += bf16_to_f32(q[d]) * bf16_to_f32(k[j*HD+d]);
            s[j] = dot * scale;
        }
        float mx = -INFINITY;
        for (int j = 0; j < SK; j++) mx = fmaxf(mx, s[j]);
        float sm = 0.0f;
        for (int j = 0; j < SK; j++) { s[j] = expf(s[j]-mx); sm += s[j]; }
        for (int j = 0; j < SK; j++) s[j] /= sm;
        for (int d = 0; d < HD; d++) {
            float ov = 0.0f;
            for (int j = 0; j < SK; j++) ov += s[j] * bf16_to_f32(v[d*SK+j]);
            o_scalar[d] = ov;
        }
    }

    // 7) Release TMEM from the warp that allocated it.
    if (wid == 1) {
        asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
        tmem_dealloc(tb, 128);
    }
}
/**
 * Host driver: builds deterministic bf16 inputs, runs test_fmha_hd64 on one
 * 128-thread block, and compares the UMMA output with the in-kernel scalar
 * reference. Returns 0 when cosine similarity > 0.999, 1 otherwise or on any
 * CUDA error.
 */
int main() {
    printf("=== Full UMMA FMHA HD=64 (decode) ===\n");
    const float SCALE = 1.0f / sqrtf((float)HD);

    // Prints a failing CUDA runtime call (with the step name) and returns false.
    auto cuda_ok = [](cudaError_t err, const char* what) -> bool {
        if (err == cudaSuccess) return true;
        printf("CUDA ERROR: %s [%s]\n", cudaGetErrorString(err), what);
        return false;
    };

    // Host buffers; std::vector frees them on every return path.
    // Value-initialisation zeroes h_o / h_o_scalar like the former calloc.
    std::vector<bf16_t> h_q(HD), h_k((size_t)SK * HD), h_v((size_t)HD * SK), h_o(HD);
    std::vector<float> h_o_scalar(HD);

    // Deterministic inputs in [-0.5, 0.49]; generation order (q, k, v) is fixed.
    srand(42);
    for (int d = 0; d < HD; d++) h_q[d] = f32_to_bf16_host((float)(rand()%100)/100.0f-0.5f);
    for (int i = 0; i < SK*HD; i++) h_k[i] = f32_to_bf16_host((float)(rand()%100)/100.0f-0.5f);
    for (int i = 0; i < HD*SK; i++) h_v[i] = f32_to_bf16_host((float)(rand()%100)/100.0f-0.5f);

    bf16_t *d_q = nullptr, *d_k = nullptr, *d_v = nullptr, *d_o = nullptr;
    float* d_o_scalar = nullptr;
    // cudaFree(nullptr) is a no-op, so this is safe after a partial allocation.
    auto free_device = [&]() {
        cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o); cudaFree(d_o_scalar);
    };

    const size_t q_bytes  = HD * sizeof(bf16_t);
    const size_t kv_bytes = (size_t)SK * HD * sizeof(bf16_t);
    const size_t o_bytes  = HD * sizeof(bf16_t);
    const size_t os_bytes = HD * sizeof(float);

    bool ok = cuda_ok(cudaMalloc(&d_q, q_bytes), "cudaMalloc q")
           && cuda_ok(cudaMalloc(&d_k, kv_bytes), "cudaMalloc k")
           && cuda_ok(cudaMalloc(&d_v, kv_bytes), "cudaMalloc v")
           && cuda_ok(cudaMalloc(&d_o, o_bytes), "cudaMalloc o")
           && cuda_ok(cudaMalloc(&d_o_scalar, os_bytes), "cudaMalloc o_scalar")
           && cuda_ok(cudaMemcpy(d_q, h_q.data(), q_bytes, cudaMemcpyHostToDevice), "copy q")
           && cuda_ok(cudaMemcpy(d_k, h_k.data(), kv_bytes, cudaMemcpyHostToDevice), "copy k")
           && cuda_ok(cudaMemcpy(d_v, h_v.data(), kv_bytes, cudaMemcpyHostToDevice), "copy v");
    if (!ok) { free_device(); return 1; }

    // Dynamic smem: TMEM-address slot + alignment pad + Q/K tiles + 256 bytes
    // of slack, rounded up to a multiple of 128 bytes.
    const int smem = (int)((4 + 16 + 2 * NKT * TILE_SZ * sizeof(bf16_t) + 256 + 127) & ~(size_t)127);
    // More than 48 KB of dynamic shared memory requires an explicit opt-in.
    if (smem > 48 * 1024 &&
        !cuda_ok(cudaFuncSetAttribute(test_fmha_hd64, cudaFuncAttributeMaxDynamicSharedMemorySize, smem),
                 "smem opt-in")) {
        free_device();
        return 1;
    }

    test_fmha_hd64<<<1, 128, smem>>>(d_q, d_k, d_v, d_o, d_o_scalar, SCALE);
    // cudaGetLastError catches launch-configuration errors; the synchronize
    // surfaces faults raised while the kernel executed.
    ok = cuda_ok(cudaGetLastError(), "kernel launch")
      && cuda_ok(cudaDeviceSynchronize(), "kernel execution")
      && cuda_ok(cudaMemcpy(h_o.data(), d_o, o_bytes, cudaMemcpyDeviceToHost), "copy o")
      && cuda_ok(cudaMemcpy(h_o_scalar.data(), d_o_scalar, os_bytes, cudaMemcpyDeviceToHost), "copy o_scalar");
    free_device();
    if (!ok) return 1;

    printf("O[0..7] MMA: "); for (int d = 0; d < 8; d++) printf("%.6f ", bf16_to_f32_host(h_o[d])); printf("\n");
    printf("O[0..7] ref: "); for (int d = 0; d < 8; d++) printf("%.6f ", h_o_scalar[d]); printf("\n");

    // Max absolute error normalised by the largest reference magnitude.
    float max_diff = 0, max_val = 0;
    for (int d = 0; d < HD; d++) {
        max_diff = fmaxf(max_diff, fabsf(bf16_to_f32_host(h_o[d]) - h_o_scalar[d]));
        max_val = fmaxf(max_val, fabsf(h_o_scalar[d]));
    }
    float rel_err = max_val > 0 ? max_diff / max_val : max_diff;

    float cos_sim = 0, na = 0, nb = 0;
    for (int d = 0; d < HD; d++) { float a = bf16_to_f32_host(h_o[d]), b = h_o_scalar[d]; cos_sim += a*b; na += a*a; nb += b*b; }
    cos_sim /= (sqrtf(na) * sqrtf(nb) + 1e-10f);

    printf("Max rel err: %.8f | cosine: %.8f\n", rel_err, cos_sim);
    const bool passed = cos_sim > 0.999f;
    printf("Test %s\n", passed ? "PASSED" : "FAILED");
    return passed ? 0 : 1;
}