Files
nvfp4-megamoe-kernel/tests/unit/test_qk_mma.cu

118 lines
3.9 KiB
Plaintext

/**
* Standalone test for tcgen05.mma QK GEMM verification.
* Compares tensor-core Q@K^T output against CPU reference.
*/
#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
#include "dsv4/kernels/attention/fmha_qk_verify.cuh"
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <float.h>
#include <string.h>
using namespace dsv4::kernels::attention;
/**
 * Converts an FP32 value to its BF16 bit pattern on the host.
 *
 * Uses round-to-nearest, ties-to-even instead of plain truncation, so the
 * conversion carries no systematic bias toward zero. NaN inputs stay NaN:
 * a quiet-NaN mantissa bit is forced so that a payload living only in the
 * discarded low 16 bits cannot collapse into +/-Inf.
 *
 * @param f  FP32 value to convert.
 * @return   Upper-16-bit BF16 encoding of the rounded value.
 */
uint16_t f32_to_bf16_cpu(float f) {
    uint32_t u;
    memcpy(&u, &f, sizeof(u));
    // Exponent all ones with a non-zero mantissa => NaN; do not round it.
    if ((u & 0x7FFFFFFFu) > 0x7F800000u) {
        return (uint16_t)((u >> 16) | 0x0040u);
    }
    // Add 0x7FFF plus the LSB of the kept half: exact ties go to the even value.
    u += 0x7FFFu + ((u >> 16) & 1u);
    return (uint16_t)(u >> 16);
}
/**
 * Expands a BF16 bit pattern to FP32 on the host.
 *
 * The 16 input bits become the upper half of the 32-bit word and the lower
 * half is zero-filled; the word is then reinterpreted as a float.
 *
 * @param h  BF16 bit pattern.
 * @return   FP32 value with the same sign, exponent and top mantissa bits.
 */
float bf16_to_f32_cpu(uint16_t h) {
    const uint32_t widened = static_cast<uint32_t>(h) << 16;
    float result;
    memcpy(&result, &widened, sizeof(result));
    return result;
}
/**
 * Cosine similarity of two float arrays of length n.
 *
 * Index positions where either element is non-finite (NaN or Inf) are left
 * out of all three accumulations. Returns 0 when the product of the two
 * norms is not strictly positive (e.g. n == 0 or an all-zero vector).
 *
 * @param a  First array, at least n elements.
 * @param b  Second array, at least n elements.
 * @param n  Number of elements to compare.
 * @return   dot(a, b) / (|a| * |b|) over the finite pairs, or 0.
 */
float cosine_sim(const float* a, const float* b, int n) {
    float inner = 0, sq_a = 0, sq_b = 0;
    for (int idx = 0; idx < n; ++idx) {
        const float x = a[idx];
        const float y = b[idx];
        if (isfinite(x) && isfinite(y)) {
            inner += x * y;
            sq_a += x * x;
            sq_b += y * y;
        }
    }
    const float denom = sqrtf(sq_a) * sqrtf(sq_b);
    if (denom > 0) {
        return inner / denom;
    }
    return 0;
}
/** Prints a diagnostic when a CUDA runtime call failed; returns true on cudaSuccess. */
static bool qk_test_cuda_ok(cudaError_t err, const char* what) {
    if (err == cudaSuccess) return true;
    printf("❌ %s failed: %s\n", what, cudaGetErrorString(err));
    return false;
}

/**
 * Test driver for the fmha_qk_verify<HD> kernel.
 *
 * Generates random Q (1 x HD) and K (s_k x HD), quantizes both to BF16,
 * computes a scaled CPU reference S = Q @ K^T * scale from the quantized
 * values, launches the kernel once, and compares the s_k output floats
 * against the reference via cosine similarity and max absolute difference.
 *
 * Every CUDA runtime call is checked, and all host/device buffers are
 * released on every exit path.
 *
 * @return 0 when cosine similarity > 0.99 and no output is non-finite;
 *         1 otherwise, or when any allocation / CUDA call fails.
 */
int main() {
    printf("=== tcgen05.mma QK GEMM Verification ===\n\n");
    constexpr int HD = 64;
    const int s_k = 128;
    const int B = 1, H = 1;
    const float scale = 1.0f / sqrtf((float)HD);

    // Element counts, widened to size_t so byte-size products cannot overflow int.
    const size_t q_elems   = (size_t)B * H * HD;
    const size_t k_elems   = (size_t)B * s_k * HD;
    const size_t ref_elems = (size_t)B * s_k;
    const size_t out_elems = (size_t)s_k;

    // Host buffers: FP32 sources, BF16 copies, CPU reference, GPU read-back.
    float*    hq     = (float*)malloc(q_elems * sizeof(float));
    float*    hk     = (float*)malloc(k_elems * sizeof(float));
    float*    hs_ref = (float*)malloc(ref_elems * sizeof(float));
    float*    hs_gpu = (float*)malloc(out_elems * sizeof(float));
    uint16_t* hqb    = (uint16_t*)malloc(q_elems * sizeof(uint16_t));
    uint16_t* hkb    = (uint16_t*)malloc(k_elems * sizeof(uint16_t));

    // Device buffers start as nullptr so cleanup is safe at any point.
    uint16_t* dq = nullptr;
    uint16_t* dk = nullptr;
    float* ds_out = nullptr;

    // Single cleanup routine shared by all exit paths (free/cudaFree accept null).
    auto release = [&]() {
        cudaFree(dq); cudaFree(dk); cudaFree(ds_out);
        free(hq); free(hk); free(hs_ref); free(hs_gpu); free(hqb); free(hkb);
    };

    if (!hq || !hk || !hs_ref || !hs_gpu || !hqb || !hkb) {
        printf("❌ Host allocation failed\n");
        release();
        return 1;
    }

    // Deterministic inputs in [-0.5, 0.5].
    srand(42);
    for (size_t i = 0; i < q_elems; i++) hq[i] = (float)rand() / RAND_MAX - 0.5f;
    for (size_t i = 0; i < k_elems; i++) hk[i] = (float)rand() / RAND_MAX - 0.5f;

    // Convert to BF16
    for (size_t i = 0; i < q_elems; i++) hqb[i] = f32_to_bf16_cpu(hq[i]);
    for (size_t i = 0; i < k_elems; i++) hkb[i] = f32_to_bf16_cpu(hk[i]);

    // CPU reference: Q(1, HD) @ K^T(HD, s_k) → S(1, s_k), scaled.
    // Built from the BF16-rounded inputs (the same values uploaded to the
    // device) so the comparison is not polluted by input quantization error.
    for (int j = 0; j < s_k; j++) {
        float dot = 0;
        for (int d = 0; d < HD; d++) {
            dot += bf16_to_f32_cpu(hqb[d]) * bf16_to_f32_cpu(hkb[(size_t)j * HD + d]);
        }
        hs_ref[j] = dot * scale;
    }

    // GPU alloc + upload; && short-circuits after the first failure.
    const bool setup_ok =
        qk_test_cuda_ok(cudaMalloc(&dq, q_elems * sizeof(uint16_t)), "cudaMalloc(dq)") &&
        qk_test_cuda_ok(cudaMalloc(&dk, k_elems * sizeof(uint16_t)), "cudaMalloc(dk)") &&
        qk_test_cuda_ok(cudaMalloc(&ds_out, out_elems * sizeof(float)), "cudaMalloc(ds_out)") &&
        qk_test_cuda_ok(cudaMemcpy(dq, hqb, q_elems * sizeof(uint16_t), cudaMemcpyHostToDevice),
                        "cudaMemcpy(dq)") &&
        qk_test_cuda_ok(cudaMemcpy(dk, hkb, k_elems * sizeof(uint16_t), cudaMemcpyHostToDevice),
                        "cudaMemcpy(dk)") &&
        qk_test_cuda_ok(cudaMemset(ds_out, 0, out_elems * sizeof(float)), "cudaMemset(ds_out)");
    if (!setup_ok) {
        release();
        return 1;
    }

    // SMEM = 4 (tmem_base) + 128B align + 128*HD*2 (sQ) + 128*HD*2 (sK) + slack
    constexpr int smem = 256 + 128 * HD * 2 + 128 * HD * 2 + 4096;
    // Above 48 KB a kernel must opt in through cudaFuncSetAttribute
    // (cudaFuncAttributeMaxDynamicSharedMemorySize); fail the build if a
    // larger HD ever crosses that line without the opt-in being added.
    static_assert(smem <= 48 * 1024,
                  "dynamic shared memory exceeds 48 KB; add a cudaFuncSetAttribute opt-in");

    dim3 grid(1, H, B);
    dim3 block(NTHREADS);
    printf("Running QK GEMM (hd=%d, s_k=%d)...\n", HD, s_k);
    fmha_qk_verify<HD><<<grid, block, smem>>>(
        dq, dk, ds_out,
        H*HD, s_k*HD, s_k, scale
    );
    // Launch-configuration errors surface here; in-kernel faults at the sync.
    if (!qk_test_cuda_ok(cudaGetLastError(), "Kernel launch") ||
        !qk_test_cuda_ok(cudaDeviceSynchronize(), "Kernel")) {
        release();
        return 1;
    }

    // Read back
    if (!qk_test_cuda_ok(cudaMemcpy(hs_gpu, ds_out, out_elems * sizeof(float),
                                    cudaMemcpyDeviceToHost),
                         "cudaMemcpy(hs_gpu)")) {
        release();
        return 1;
    }

    // Compare — but also print raw values for debugging.
    // NOTE(review): the labels below imply the kernel stores debug values in
    // output slots 1-3; confirm against fmha_qk_verify before relying on them.
    printf("Scalar QK[0]: %.6f (ref: %.6f)\n", hs_gpu[0], hs_ref[0]);
    printf("MMA QK[0]: %.6f\n", hs_gpu[3]);
    printf("sQ[0]: %.6f sK[0]: %.6f\n", hs_gpu[1], hs_gpu[2]);

    float max_diff = 0;
    int nan_count = 0;
    for (int j = 0; j < s_k; j++) {
        if (!isfinite(hs_gpu[j])) nan_count++;
        else max_diff = fmaxf(max_diff, fabsf(hs_gpu[j] - hs_ref[j]));
    }
    const float cos_sim = cosine_sim(hs_gpu, hs_ref, s_k);
    const int pass = (cos_sim > 0.99f && nan_count == 0);
    printf("QK GEMM: cos %.6f max_diff %.6f nan=%d %s\n",
           cos_sim, max_diff, nan_count, pass ? "✅" : "❌");
    // Dump the first few values on failure or when the error is suspiciously large.
    if (!pass || max_diff > 0.1f) {
        printf(" GPU[:4] = %.6f %.6f %.6f %.6f\n", hs_gpu[0], hs_gpu[1], hs_gpu[2], hs_gpu[3]);
        printf(" Ref[:4] = %.6f %.6f %.6f %.6f\n", hs_ref[0], hs_ref[1], hs_ref[2], hs_ref[3]);
    }

    release();
    printf("\n%s\n", pass ? "✅ QK GEMM PASSED!" : "❌ QK GEMM FAILED");
    return pass ? 0 : 1;
}