Clean up HD=64 test; V layout verified correct
This commit is contained in:
@@ -189,37 +189,6 @@ test_fmha_hd64_smem_p(const bf16_t* q, const bf16_t* k, const bf16_t* v,
|
||||
}
|
||||
}
|
||||
|
||||
// ===== Register-math PV for verification =====
|
||||
float* s_ref_o = s_p_vals; // Reuse s_p_vals (already consumed by PV)
|
||||
if (tid == 0) {
|
||||
float s_ref[SK];
|
||||
for (int j=0;j<SK;j++) s_ref[j] = s_p_vals[j];
|
||||
for (int d=0;d<HD;d++) {
|
||||
float ov = 0.0f;
|
||||
for (int j=0;j<SK;j++) ov += s_ref[j] * bf16_to_f32(v[d*SK+j]);
|
||||
s_ref_o[d] = ov;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
if (tid == 0) { printf("RegPV[0..7]: "); for(int d=0;d<8;d++) printf("%.6f ", s_ref_o[d]); printf("\n"); }
|
||||
__syncthreads();
|
||||
|
||||
// Verify V SMEM values for kt=0
|
||||
if (tid == 0) {
|
||||
bf16_t* sv0 = sV; // kt=0, nt=0
|
||||
// Read a few positions and compare with v[]
|
||||
for (int d = 0; d < 4; d++) {
|
||||
for (int r = 0; r < 4; r++) {
|
||||
int g_mn = d / 8, g_k = r / 8;
|
||||
int llr = d % 8, lc = r % 8;
|
||||
int idx = g_k * 8 * 64 + g_mn * 64 + llr * 8 + lc;
|
||||
printf("V_smem[%d,%d]=%.4f vs V_gmem=%.4f\n", d, r,
|
||||
bf16_to_f32(sv0[idx]), bf16_to_f32(v[d * SK + r]));
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (wid == 0) tmem_dealloc(tb, 128);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user