Files
nvfp4-megamoe-kernel/tests/unit/test_fmha_hd64_debug.cu
2026-05-28 15:46:53 +00:00

283 lines
11 KiB
Plaintext

/**
* Debug HD=64 PV precision — compare register-math PV vs SS MMA PV
* to isolate where the 0.931 cosine error comes from.
*/
#include <cuda_runtime.h>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
using namespace dsv4::kernels::attention;
// Host FP32 -> BF16 conversion using round-to-nearest-even (the rounding mode of
// CUDA's __float2bfloat16) instead of plain truncation, which would bias every
// generated test input toward zero. NaNs stay NaN (quiet bit forced) instead of
// collapsing to +/-Inf when their payload lives only in the low 16 bits.
static bf16_t f32_to_bf16_host(float f) {
    uint32_t u;
    memcpy(&u, &f, 4);
    if ((u & 0x7fffffffu) > 0x7f800000u)      // NaN: keep sign + upper payload, force quiet
        return (uint16_t)((u >> 16) | 0x0040u);
    u += 0x7fffu + ((u >> 16) & 1u);          // round to nearest, ties to even
    return (uint16_t)(u >> 16);
}
// Host BF16 -> FP32 widening. The 16 BF16 bits become the high half of the
// FP32 word with a zero low half, so the conversion is exact.
static float bf16_to_f32_host(bf16_t h) {
    const uint32_t widened = (uint32_t)h << 16;
    float result;
    memcpy(&result, &widened, sizeof(result));
    return result;
}
// Problem shape: one query row (T = 1), SK keys, head dim HD, computed inside a
// BLOCK_MN x BLOCK_MN UMMA tile.
constexpr int HD = 64, SK = 128, BLOCK_MN = 128;
constexpr int NKT_QK = HD / MMA_K_BF16;         // 4 UMMA K-steps over the head dim (S = Q K^T)
constexpr int NKT_PV = SK / MMA_K_BF16;         // 8 UMMA K-steps over the keys     (O = P V)
constexpr int TILE_SZ = BLOCK_MN * MMA_K_BF16;  // 2048 BF16: one BLOCK_MN x MMA_K operand tile
constexpr int V_TILE_SZ = HD * MMA_K_BF16;      // 1024 BF16: one HD x MMA_K V tile
// The shared-memory loaders hard-code 8x8 core matrices: MMA_K_BF16 / 8 == 2
// K-chunks per tile, BLOCK_MN / 8 == 16 MN-groups (the `16 * 64` stride) for
// Q/K/P and HD / 8 == 8 MN-groups (the `8 * 64` stride) for V.
static_assert(MMA_K_BF16 == 16, "loaders assume 16-element (2 core-matrix) UMMA K tiles");
static_assert(BLOCK_MN == 128, "Q/K/P loaders hard-code 16 MN core-matrix groups");
static_assert(HD == 64, "V loader hard-codes 8 MN core-matrix groups (HD = 64)");
static_assert(HD % MMA_K_BF16 == 0 && SK % MMA_K_BF16 == 0,
              "HD and SK must be multiples of the UMMA K step");
static_assert(SK <= BLOCK_MN, "all SK keys must fit in the N dimension of one S tile");
// ---------------------------------------------------------------------------
// UMMA completion helpers (inline PTX; tcgen05 requires sm_100a, mbarrier
// try_wait.parity requires sm_90+).
// tcgen05.mma executes asynchronously: a tcgen05.fence plus __syncthreads()
// does not wait for it to finish. Before the TMEM accumulator is read with
// tcgen05.ld, or a shared-memory operand is overwritten, the issuing thread
// commits to an mbarrier and all threads wait for that arrival.
// ---------------------------------------------------------------------------

// Initialize a CTA-shared mbarrier that completes a phase after `count` arrivals.
static __device__ __forceinline__ void dbg_mbar_init(uint64_t* bar, uint32_t count) {
    asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;"
                 :: "r"((uint32_t)__cvta_generic_to_shared(bar)), "r"(count) : "memory");
}

// Arrive once on `bar` when every tcgen05 op previously issued by this thread completes.
static __device__ __forceinline__ void dbg_umma_commit(uint64_t* bar) {
    asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
                 :: "r"((uint32_t)__cvta_generic_to_shared(bar)) : "memory");
}

// Spin until the mbarrier phase with the given parity has completed.
static __device__ __forceinline__ void dbg_mbar_wait(uint64_t* bar, uint32_t parity) {
    const uint32_t addr = (uint32_t)__cvta_generic_to_shared(bar);
    uint32_t done = 0;
    do {
        asm volatile("{\n\t"
                     ".reg .pred p;\n\t"
                     "mbarrier.try_wait.parity.shared::cta.b64 p, [%1], %2;\n\t"
                     "selp.u32 %0, 1, 0, p;\n\t"
                     "}"
                     : "=r"(done) : "r"(addr), "r"(parity) : "memory");
    } while (!done);
}

/**
 * Single-CTA debug kernel: launch as <<<1, 128, smem>>> with the dynamic smem
 * size computed in main(). Computes one attention row (T = 1) three ways:
 *   - o_mma     : S = Q K^T and O = P V via SS UMMA into TMEM (BF16 output),
 *   - o_regmath : O = P V in FP32 on thread 0, using the softmax probabilities
 *                 derived from the UMMA scores,
 *   - o_ref     : fully scalar FP32 reference (scores, softmax, PV) on thread 0.
 * Layouts: q is [HD], k is [SK][HD] row-major, v is stored transposed as [HD][SK].
 */
__global__ void __launch_bounds__(128)
test_fmha_hd64_debug(const bf16_t* q, const bf16_t* k, const bf16_t* v,
                     bf16_t* o_mma, float* o_ref, float* o_regmath,
                     float scale)
{
    const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32;
    extern __shared__ char sbuf[];
    // Arrived on (via tcgen05.commit) when all previously issued UMMAs complete.
    __shared__ __align__(8) uint64_t s_mma_bar;
    uint32_t mma_phase = 0;  // parity of the next s_mma_bar phase to wait for

    // Dynamic smem: [TMEM addr slot][Q tiles][K tiles][P tile][V tiles][P row, FP32].
    uint32_t* sTmemBase = (uint32_t*)sbuf;
    bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
    bf16_t* sK0 = sQ0 + NKT_QK * TILE_SZ;
    bf16_t* sPk = (bf16_t*)(((uintptr_t)(sK0 + NKT_QK * TILE_SZ) + 127) & ~(uintptr_t)127);
    bf16_t* sV = (bf16_t*)(((uintptr_t)(sPk + TILE_SZ) + 127) & ~(uintptr_t)127);
    float* s_p_vals = (float*)(sV + NKT_PV * V_TILE_SZ);

    // ===== Zero-fill operand tiles, init mbarrier =====
    // The zero-fill must be complete and visible BEFORE the scatter: a slot is
    // generally zeroed by a different thread (often a different warp) than the
    // one that writes its real value, so without this barrier a late zero
    // store can overwrite live Q/K/V data.
    for (int i = tid; i < NKT_QK * TILE_SZ; i += 128) { sQ0[i] = 0; sK0[i] = 0; }
    for (int i = tid; i < NKT_PV * V_TILE_SZ; i += 128) sV[i] = 0;
    if (tid == 0) {
        dbg_mbar_init(&s_mma_bar, 1);
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
    }
    __syncthreads();

    // ===== Scatter Q / K / V into the core-matrix tiles read by make_umma_desc_kmajor_none =====
    // Q: the single query row lands in row 0 (tmn = 0, lr = 0) of each K-tile.
    for (int idx = tid; idx < HD; idx += 128) {
        const int kt = idx / MMA_K_BF16, d = idx % MMA_K_BF16;
        const int ck = d / 8, lc = d % 8;
        sQ0[kt * TILE_SZ + ck * 16 * 64 + lc] = q[idx];
    }
    // K: row r, head-dim column col (flat index keeps global reads coalesced).
    for (int idx = tid; idx < SK * HD; idx += 128) {
        const int r = idx / HD, col = idx % HD;
        const int kt = col / MMA_K_BF16, d = col % MMA_K_BF16;
        const int ck = d / 8, lc = d % 8, tmn = r / 8, lr = r % 8;
        sK0[kt * TILE_SZ + ck * 16 * 64 + tmn * 64 + lr * 8 + lc] = k[idx];
    }
    // V (stored [HD][SK]): MN = head dim d, K = key r.
    for (int idx = tid; idx < HD * SK; idx += 128) {
        const int d = idx / SK, r = idx % SK;
        const int kt = r / MMA_K_BF16, lr = r % MMA_K_BF16;
        const int g_mn = d / 8, g_k = lr / 8, llr = d % 8, lc = lr % 8;
        sV[kt * V_TILE_SZ + g_k * 8 * 64 + g_mn * 64 + llr * 8 + lc] = v[idx];
    }
    // Generic-proxy smem stores must be made visible to the async proxy (UMMA).
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
    __syncthreads();

    // ===== TMEM alloc =====
    if (wid == 1) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 128);
    asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
    __syncthreads();
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    uint32_t tb = *sTmemBase;

    // ===== QK GEMM (NKT_QK K-steps, accumulated in TMEM) =====
    {
        uint32_t idesc = make_idesc(BLOCK_MN, BLOCK_MN);
        if (tid == 0) {
            for (int kt = 0; kt < NKT_QK; kt++) {
                uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0 + kt * TILE_SZ), BLOCK_MN);
                uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK0 + kt * TILE_SZ), BLOCK_MN);
                umma_ss_f16(tb, dq, dk, idesc, kt > 0);
            }
            dbg_umma_commit(&s_mma_bar);
        }
        // TMEM may only be read once the MMAs have actually completed.
        dbg_mbar_wait(&s_mma_bar, mma_phase);
        mma_phase ^= 1;
        asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    }

    // ===== Softmax (warp 0, row 0 only for T=1 decode) =====
    // Scores for row 0 are read from TMEM lane 0; lane 0 holds the real values,
    // the other lanes contribute -inf / 0 to the warp reductions.
    if (wid == 0) {
        float s_vals[SK], row_max = -INFINITY;
        for (int n = 0; n < SK / 8; n++) {
            float tmp[8];
            asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
                : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
                  "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
                : "r"(tb + n*8));
            asm volatile("tcgen05.wait::ld.sync.aligned;");
            if (lane == 0) for (int c=0;c<8;c++) {
                s_vals[n*8+c] = tmp[c] * scale;
                row_max = fmaxf(row_max, tmp[c] * scale);
            }
        }
        row_max = wmax(row_max);
        float row_sum = 0.0f;
        if (lane == 0) for (int j=0;j<SK;j++) {
            s_vals[j] = expf(s_vals[j] - row_max);
            row_sum += s_vals[j];
        }
        row_sum = wsum(row_sum);
        if (lane == 0) for (int j=0;j<SK;j++) s_vals[j] /= row_sum;
        if (lane == 0) for (int j=0;j<SK;j++) s_p_vals[j] = s_vals[j];
    }
    // Order the TMEM reads above before the PV MMA that overwrites the accumulator.
    asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
    __syncthreads();

    // ===== PV Method 1: Register math (FP32, proven correct) =====
    if (tid == 0) {
        for (int d = 0; d < HD; d++) {
            float ov = 0.0f;
            for (int j = 0; j < SK; j++) ov += s_p_vals[j] * bf16_to_f32(v[d * SK + j]);
            o_regmath[d] = ov;
        }
    }

    // ===== PV Method 2: SS MMA with BLOCK_MN_B=64 =====
    {
        uint32_t idesc_pv = make_idesc(BLOCK_MN, HD); // (128, 64)
        const uint64_t dp = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sPk), BLOCK_MN);
        for (int kt = 0; kt < NKT_PV; kt++) {
            // Fill the single P tile in one pass: row 0 of K-chunk ck occupies
            // [ck*16*64, ck*16*64 + 8); everything else is zero. One writer per
            // element, so no zero-vs-value race.
            for (int i = tid; i < TILE_SZ; i += 128) {
                const int ck = i / (16 * 64), off = i % (16 * 64);
                if (off < 8) sPk[i] = f32_to_bf16(s_p_vals[kt * MMA_K_BF16 + ck * 8 + off]);
                else sPk[i] = 0;
            }
            asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
            __syncthreads();
            bf16_t* sv = sV + kt * V_TILE_SZ;
            uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sv), HD);
            if (tid == 0) {
                asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
                umma_ss_f16(tb, dp, dv, idesc_pv, kt > 0);
                dbg_umma_commit(&s_mma_bar);
            }
            // sPk is reused next iteration and TMEM is read after the loop, so
            // the MMA must have completed before anyone proceeds.
            dbg_mbar_wait(&s_mma_bar, mma_phase);
            mma_phase ^= 1;
            asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
        }
    }

    // ===== Read MMA output (row 0 = TMEM lane 0, HD columns) =====
    if (wid == 0) {
        float o_vals[HD];
        for (int n = 0; n < HD / 8; n++) {
            float tmp[8];
            asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
                : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
                  "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
                : "r"(tb + n*8));
            asm volatile("tcgen05.wait::ld.sync.aligned;");
            if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c];
        }
        if (lane == 0) for (int d=0;d<HD;d++) o_mma[d] = f32_to_bf16(o_vals[d]);
    }
    __syncthreads();

    // ===== Scalar FP32 reference =====
    if (tid == 0) {
        float s[SK];
        for (int j=0;j<SK;j++) {
            float dot = 0.0f;
            for (int d=0;d<HD;d++) dot += bf16_to_f32(q[d]) * bf16_to_f32(k[j*HD+d]);
            s[j] = dot * scale;
        }
        float mx = -INFINITY;
        for (int j=0;j<SK;j++) mx = fmaxf(mx, s[j]);
        float sm = 0.0f;
        for (int j=0;j<SK;j++) { s[j] = expf(s[j]-mx); sm += s[j]; }
        for (int j=0;j<SK;j++) s[j] /= sm;
        for (int d=0;d<HD;d++) {
            float ov = 0.0f;
            for (int j=0;j<SK;j++) ov += s[j] * bf16_to_f32(v[d*SK+j]);
            o_ref[d] = ov;
        }
    }
    if (wid == 0) tmem_dealloc(tb, 128);
}
// Cosine similarity of a[0..n) and b[0..n), accumulated in FP32 in index order
// with a 1e-10 guard in the denominator. When min_abs_b >= 0, only indices with
// |b[i]| > min_abs_b contribute (NaN entries in b are then skipped).
static float dbg_host_cosine(const float* a, const float* b, int n, float min_abs_b = -1.0f) {
    float dot = 0.0f, na = 0.0f, nb = 0.0f;
    for (int i = 0; i < n; i++) {
        if (min_abs_b >= 0.0f && !(fabsf(b[i]) > min_abs_b)) continue;
        dot += a[i] * b[i];
        na += a[i] * a[i];
        nb += b[i] * b[i];
    }
    return dot / (sqrtf(na) * sqrtf(nb) + 1e-10f);
}

/**
 * Host driver: builds random BF16 Q [HD], K [SK][HD] and V [HD][SK], launches
 * one 128-thread CTA of test_fmha_hd64_debug and prints cosine similarities
 * between the UMMA PV result, the register-math PV result and the FP32 scalar
 * reference. Exit code 0 iff MMA PV matches register-math PV (cosine > 0.999);
 * 1 on mismatch or on any host/CUDA failure.
 */
int main() {
    printf("=== Debug HD=64 PV: register-math vs SS MMA ===\n");
    // Report a failed CUDA runtime call and terminate with exit code 1.
    auto check = [](cudaError_t err, const char* what) {
        if (err != cudaSuccess) {
            printf("CUDA ERROR: %s (%s)\n", cudaGetErrorString(err), what);
            exit(1);
        }
    };
    const float SCALE = 1.0f / sqrtf((float)HD);
    bf16_t* h_q = (bf16_t*)malloc(HD*sizeof(bf16_t));
    bf16_t* h_k = (bf16_t*)malloc(SK*HD*sizeof(bf16_t));
    bf16_t* h_v = (bf16_t*)malloc(HD*SK*sizeof(bf16_t));
    bf16_t* h_o_mma = (bf16_t*)calloc(HD, sizeof(bf16_t));
    float* h_o_ref = (float*)calloc(HD, sizeof(float));
    float* h_o_regmath = (float*)calloc(HD, sizeof(float));
    if (!h_q || !h_k || !h_v || !h_o_mma || !h_o_ref || !h_o_regmath) {
        printf("Host allocation failed\n");
        return 1;
    }
    // Deterministic inputs in [-0.5, 0.49] (step 0.01) before BF16 rounding.
    srand(42);
    for (int d=0;d<HD;d++) h_q[d] = f32_to_bf16_host((float)(rand()%100)/100.0f-0.5f);
    for (int i=0;i<SK*HD;i++) h_k[i] = f32_to_bf16_host((float)(rand()%100)/100.0f-0.5f);
    for (int i=0;i<HD*SK;i++) h_v[i] = f32_to_bf16_host((float)(rand()%100)/100.0f-0.5f);

    bf16_t *d_q,*d_k,*d_v,*d_o_mma; float *d_o_ref, *d_o_regmath;
    check(cudaMalloc(&d_q, HD*sizeof(bf16_t)), "cudaMalloc d_q");
    check(cudaMalloc(&d_k, SK*HD*sizeof(bf16_t)), "cudaMalloc d_k");
    check(cudaMalloc(&d_v, HD*SK*sizeof(bf16_t)), "cudaMalloc d_v");
    check(cudaMalloc(&d_o_mma, HD*sizeof(bf16_t)), "cudaMalloc d_o_mma");
    check(cudaMalloc(&d_o_ref, HD*sizeof(float)), "cudaMalloc d_o_ref");
    check(cudaMalloc(&d_o_regmath, HD*sizeof(float)), "cudaMalloc d_o_regmath");
    check(cudaMemcpy(d_q, h_q, HD*sizeof(bf16_t), cudaMemcpyHostToDevice), "copy q");
    check(cudaMemcpy(d_k, h_k, SK*HD*sizeof(bf16_t), cudaMemcpyHostToDevice), "copy k");
    check(cudaMemcpy(d_v, h_v, HD*SK*sizeof(bf16_t), cudaMemcpyHostToDevice), "copy v");

    // TMEM slot + alignment slack + Q/K/P/V tiles + FP32 P row + extra slack, rounded to 128 B.
    int smem = (4+16 + NKT_QK*TILE_SZ*2 + NKT_QK*TILE_SZ*2 + TILE_SZ*2 + NKT_PV*V_TILE_SZ*2 + SK*4 + 256 + 127) & ~127;
    printf("SMEM: %d bytes (%.1f KB)\n", smem, smem/1024.0f);
    check(cudaFuncSetAttribute(test_fmha_hd64_debug, cudaFuncAttributeMaxDynamicSharedMemorySize, smem),
          "cudaFuncSetAttribute");
    test_fmha_hd64_debug<<<1, 128, smem>>>(d_q, d_k, d_v, d_o_mma, d_o_ref, d_o_regmath, SCALE);
    check(cudaGetLastError(), "kernel launch");          // launch-config errors
    check(cudaDeviceSynchronize(), "kernel execution");  // in-kernel faults
    check(cudaMemcpy(h_o_mma, d_o_mma, HD*sizeof(bf16_t), cudaMemcpyDeviceToHost), "copy o_mma");
    check(cudaMemcpy(h_o_ref, d_o_ref, HD*sizeof(float), cudaMemcpyDeviceToHost), "copy o_ref");
    check(cudaMemcpy(h_o_regmath, d_o_regmath, HD*sizeof(float), cudaMemcpyDeviceToHost), "copy o_regmath");

    float h_o_mma_f[HD];
    for (int d=0;d<HD;d++) h_o_mma_f[d] = bf16_to_f32_host(h_o_mma[d]);

    // Compare register-math PV vs reference
    printf("\n--- Register-math PV vs FP32 reference ---\n");
    const float cs_reg = dbg_host_cosine(h_o_regmath, h_o_ref, HD);
    printf("Cosine: %.8f\n", cs_reg);

    // Compare MMA PV vs reference
    printf("\n--- MMA PV vs FP32 reference (all 64 elements) ---\n");
    for (int d=0;d<HD;d++) printf(" MMA[%2d]=%10.6f ref[%2d]=%10.6f ratio=%.4f\n",
        d, h_o_mma_f[d], d, h_o_ref[d],
        fabsf(h_o_ref[d])>1e-6f ? h_o_mma_f[d]/h_o_ref[d] : 0.0f);
    const float cs_mma = dbg_host_cosine(h_o_mma_f, h_o_ref, HD, 1e-4f);
    printf("Filtered cosine: %.8f\n", cs_mma);

    // Compare MMA PV vs register-math PV
    printf("\n--- MMA PV vs register-math PV ---\n");
    const float cs_cmp = dbg_host_cosine(h_o_mma_f, h_o_regmath, HD);
    printf("Cosine: %.8f\n", cs_cmp);

    // Check if MMA output is proportional to reference (scale factor issue)
    float ratio_sum = 0; int rc = 0;
    for (int d=0;d<HD;d++) {
        float a = h_o_mma_f[d], b = h_o_ref[d];
        if (fabsf(b)>1e-6f) { ratio_sum += a/b; rc++; }
    }
    float avg_ratio = rc>0 ? ratio_sum/rc : 0;
    printf("Average MMA/ref ratio: %.6f\n", avg_ratio);

    const bool passed = cs_cmp > 0.999f;  // NaN cosine counts as a failure
    printf("\nTest %s\n", passed ? "PASSED" : "FAILED");
    cudaFree(d_q); cudaFree(d_k); cudaFree(d_v);
    cudaFree(d_o_mma); cudaFree(d_o_ref); cudaFree(d_o_regmath);
    free(h_q); free(h_k); free(h_v); free(h_o_mma); free(h_o_ref); free(h_o_regmath);
    return passed ? 0 : 1;
}