Verify V SMEM values vs GMEM for HD=64

This commit is contained in:
2026-05-28 15:19:31 +00:00
parent bafd26707b
commit e1daad6955

View File

@@ -204,6 +204,22 @@ test_fmha_hd64_smem_p(const bf16_t* q, const bf16_t* k, const bf16_t* v,
if (tid == 0) { printf("RegPV[0..7]: "); for(int d=0;d<8;d++) printf("%.6f ", s_ref_o[d]); printf("\n"); }
__syncthreads();
// Verify V SMEM values for kt=0
if (tid == 0) {
bf16_t* sv0 = sV; // kt=0, nt=0
// Read a few positions and compare with v[]
for (int d = 0; d < 4; d++) {
for (int r = 0; r < 4; r++) {
int g_mn = d / 8, g_k = r / 8;
int llr = d % 8, lc = r % 8;
int idx = g_k * 8 * 64 + g_mn * 64 + llr * 8 + lc;
printf("V_smem[%d,%d]=%.4f vs V_gmem=%.4f\n", d, r,
bf16_to_f32(sv0[idx]), bf16_to_f32(v[d * SK + r]));
}
}
}
__syncthreads();
if (wid == 0) tmem_dealloc(tb, 128);
}