Verify V SMEM values vs GMEM for HD=64
This commit is contained in:
@@ -204,6 +204,22 @@ test_fmha_hd64_smem_p(const bf16_t* q, const bf16_t* k, const bf16_t* v,
|
||||
if (tid == 0) { printf("RegPV[0..7]: "); for(int d=0;d<8;d++) printf("%.6f ", s_ref_o[d]); printf("\n"); }
|
||||
__syncthreads();
|
||||
|
||||
// Verify V SMEM values for kt=0
|
||||
if (tid == 0) {
|
||||
bf16_t* sv0 = sV; // kt=0, nt=0
|
||||
// Read a few positions and compare with v[]
|
||||
for (int d = 0; d < 4; d++) {
|
||||
for (int r = 0; r < 4; r++) {
|
||||
int g_mn = d / 8, g_k = r / 8;
|
||||
int llr = d % 8, lc = r % 8;
|
||||
int idx = g_k * 8 * 64 + g_mn * 64 + llr * 8 + lc;
|
||||
printf("V_smem[%d,%d]=%.4f vs V_gmem=%.4f\n", d, r,
|
||||
bf16_to_f32(sv0[idx]), bf16_to_f32(v[d * SK + r]));
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (wid == 0) tmem_dealloc(tb, 128);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user