118 lines
3.9 KiB
Plaintext
118 lines
3.9 KiB
Plaintext
/**
 * Standalone test for tcgen05.mma QK GEMM verification.
 * Compares tensor-core Q@K^T output against CPU reference.
 */
|
|
#include "dsv4/kernels/attention/fmha_common.cuh"
|
|
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
|
|
#include "dsv4/kernels/attention/fmha_qk_verify.cuh"
|
|
#include <stdio.h>
|
|
#include <stdlib.h>
|
|
#include <math.h>
|
|
#include <float.h>
|
|
#include <string.h>
|
|
|
|
using namespace dsv4::kernels::attention;
|
|
|
|
/**
 * Convert a 32-bit float to bfloat16 bits (the upper 16 bits of the float encoding),
 * rounding to nearest with ties to even.
 *
 * NaN inputs are returned as a quiet NaN with the sign bit preserved. Plain truncation
 * would drop a payload that lives only in the low 16 mantissa bits and yield an infinity.
 * Finite values that round past the largest bfloat16 magnitude become +/-infinity.
 *
 * @param f  value to convert
 * @return   bfloat16 bit pattern
 */
uint16_t f32_to_bf16_cpu(float f) {
    uint32_t bits;
    memcpy(&bits, &f, sizeof(bits));

    // Exponent all ones with a non-zero mantissa: NaN. Force a mantissa bit that survives the shift.
    if ((bits & 0x7FFFFFFFu) > 0x7F800000u) {
        return (uint16_t)((bits >> 16) | 0x0040u);
    }

    // Add 0x7FFF plus the LSB of the kept half: exact ties round to the even neighbour.
    const uint32_t rounding_bias = 0x7FFFu + ((bits >> 16) & 1u);
    return (uint16_t)((bits + rounding_bias) >> 16);
}
|
|
/**
 * Expand a bfloat16 bit pattern to a 32-bit float by placing the 16 bits in the
 * upper half of a 32-bit word (low half zero) and reinterpreting it as a float.
 *
 * @param h  bfloat16 bit pattern
 * @return   the float with that upper half
 */
float bf16_to_f32_cpu(uint16_t h) {
    const uint32_t widened = static_cast<uint32_t>(h) << 16;
    float result;
    memcpy(&result, &widened, sizeof(result));
    return result;
}
|
|
|
|
/**
 * Cosine similarity between the first n entries of a and b.
 *
 * Positions where either operand is non-finite are left out of all three sums.
 *
 * @param a  first vector
 * @param b  second vector
 * @param n  number of entries to read from each vector
 * @return   dot(a,b) / (|a| * |b|), or 0 when that denominator is not positive
 */
float cosine_sim(const float* a, const float* b, int n) {
    float inner = 0, sq_a = 0, sq_b = 0;

    for (int idx = 0; idx < n; ++idx) {
        const float x = a[idx];
        const float y = b[idx];
        if (isfinite(x) && isfinite(y)) {
            inner += x * y;
            sq_a += x * x;
            sq_b += y * y;
        }
    }

    const float denom = sqrtf(sq_a) * sqrtf(sq_b);
    if (!(denom > 0)) {
        return 0;
    }
    return inner / denom;
}
|
|
|
|
/** Report a failed CUDA status on stdout; returns true only for cudaSuccess. */
static bool cuda_ok(cudaError_t status, const char* what) {
    if (status == cudaSuccess) return true;
    printf("❌ %s failed: %s\n", what, cudaGetErrorString(status));
    return false;
}

/**
 * Runs fmha_qk_verify<HD> once on random BF16 inputs and compares its float output
 * against a CPU-computed Q @ K^T * scale reference.
 *
 * Returns 0 when the cosine similarity exceeds 0.99 and no output is non-finite.
 * Returns 1 otherwise, and also on any host allocation, CUDA API, launch or
 * execution failure. All host and device buffers are released on every path.
 */
int main() {
    printf("=== tcgen05.mma QK GEMM Verification ===\n\n");

    constexpr int HD = 64;
    int s_k = 128;
    int B = 1, H = 1;
    float scale = 1.0f / sqrtf((float)HD);

    const size_t q_count = (size_t)B * H * HD;
    const size_t k_count = (size_t)B * s_k * HD;
    const size_t s_count = (size_t)s_k;

    // Host buffers: float inputs, float reference, BF16 copies of the inputs, GPU readback.
    float* hq = (float*)malloc(q_count * sizeof(float));
    float* hk = (float*)malloc(k_count * sizeof(float));
    float* hs_ref = (float*)malloc((size_t)B * s_k * sizeof(float));
    float* hs_gpu = (float*)malloc(s_count * sizeof(float));
    uint16_t* hqb = (uint16_t*)malloc(q_count * sizeof(uint16_t));
    uint16_t* hkb = (uint16_t*)malloc(k_count * sizeof(uint16_t));

    // Device buffers start null so the shared cleanup is safe at any point.
    uint16_t* dq = nullptr;
    uint16_t* dk = nullptr;
    float* ds_out = nullptr;

    auto release_all = [&]() {
        cudaFree(dq); cudaFree(dk); cudaFree(ds_out);
        free(hq); free(hk); free(hs_ref); free(hs_gpu); free(hqb); free(hkb);
    };

    if (!hq || !hk || !hs_ref || !hs_gpu || !hqb || !hkb) {
        printf("❌ Host allocation failed\n");
        release_all();
        return 1;
    }

    srand(42);
    for (size_t i = 0; i < q_count; i++) hq[i] = (float)rand() / RAND_MAX - 0.5f;
    for (size_t i = 0; i < k_count; i++) hk[i] = (float)rand() / RAND_MAX - 0.5f;

    // CPU reference: Q(1, HD) @ K^T(HD, s_k) → S(1, s_k), then multiplied by scale.
    for (int j = 0; j < s_k; j++) {
        float dot = 0;
        for (int d = 0; d < HD; d++) dot += hq[d] * hk[j * HD + d];
        hs_ref[j] = dot * scale;
    }

    // Convert to BF16
    for (size_t i = 0; i < q_count; i++) hqb[i] = f32_to_bf16_cpu(hq[i]);
    for (size_t i = 0; i < k_count; i++) hkb[i] = f32_to_bf16_cpu(hk[i]);

    // GPU alloc + upload; && short-circuits so the first failure stops the chain.
    bool ok =
        cuda_ok(cudaMalloc(&dq, q_count * sizeof(uint16_t)), "cudaMalloc(dq)") &&
        cuda_ok(cudaMalloc(&dk, k_count * sizeof(uint16_t)), "cudaMalloc(dk)") &&
        cuda_ok(cudaMalloc(&ds_out, s_count * sizeof(float)), "cudaMalloc(ds_out)") &&
        cuda_ok(cudaMemcpy(dq, hqb, q_count * sizeof(uint16_t), cudaMemcpyHostToDevice),
                "cudaMemcpy(dq)") &&
        cuda_ok(cudaMemcpy(dk, hkb, k_count * sizeof(uint16_t), cudaMemcpyHostToDevice),
                "cudaMemcpy(dk)") &&
        cuda_ok(cudaMemset(ds_out, 0, s_count * sizeof(float)), "cudaMemset(ds_out)");
    if (!ok) {
        release_all();
        return 1;
    }

    // SMEM = 4 (tmem_base) + 128B align + 128*HD*2 (sQ) + 128*HD*2 (sK) + slack
    // For HD = 64 this is 37120 bytes.
    int smem = 256 + 128 * HD * 2 + 128 * HD * 2 + 4096;

    dim3 grid(1, H, B);
    dim3 block(NTHREADS);

    printf("Running QK GEMM (hd=%d, s_k=%d)...\n", HD, s_k);
    fmha_qk_verify<HD><<<grid, block, smem>>>(
        dq, dk, ds_out,
        H * HD, s_k * HD, s_k, scale
    );

    // cudaGetLastError reports launch-configuration errors; the synchronize reports
    // faults raised while the kernel was executing.
    ok = cuda_ok(cudaGetLastError(), "Kernel") &&
         cuda_ok(cudaDeviceSynchronize(), "Kernel");
    if (!ok) {
        release_all();
        return 1;
    }

    // Read back; checked so the comparison never runs on uninitialized host memory.
    if (!cuda_ok(cudaMemcpy(hs_gpu, ds_out, s_count * sizeof(float), cudaMemcpyDeviceToHost),
                 "cudaMemcpy(hs_gpu)")) {
        release_all();
        return 1;
    }

    // Compare — but also print raw values for debugging
    // NOTE(review): the labels below suggest the kernel stores debug values in output
    // slots 1..3 — confirm against fmha_qk_verify.
    printf("Scalar QK[0]: %.6f (ref: %.6f)\n", hs_gpu[0], hs_ref[0]);
    printf("MMA QK[0]: %.6f\n", hs_gpu[3]);
    printf("sQ[0]: %.6f sK[0]: %.6f\n", hs_gpu[1], hs_gpu[2]);

    float max_diff = 0;
    int nan_count = 0;
    for (int j = 0; j < s_k; j++) {
        if (!isfinite(hs_gpu[j])) nan_count++;
        else max_diff = fmaxf(max_diff, fabsf(hs_gpu[j] - hs_ref[j]));
    }

    float cos_sim = cosine_sim(hs_gpu, hs_ref, s_k);
    int pass = (cos_sim > 0.99f && nan_count == 0);
    printf("QK GEMM: cos %.6f max_diff %.6f nan=%d %s\n",
           cos_sim, max_diff, nan_count, pass ? "✅" : "❌");

    // Dump the leading values on failure or when the absolute error is large.
    if (!pass || max_diff > 0.1f) {
        printf(" GPU[:4] = %.6f %.6f %.6f %.6f\n", hs_gpu[0], hs_gpu[1], hs_gpu[2], hs_gpu[3]);
        printf(" Ref[:4] = %.6f %.6f %.6f %.6f\n", hs_ref[0], hs_ref[1], hs_ref[2], hs_ref[3]);
    }

    release_all();

    printf("\n%s\n", pass ? "✅ QK GEMM PASSED!" : "❌ QK GEMM FAILED");
    return pass ? 0 : 1;
}
|