diff --git a/tests/unit/test_fmha_hd64_smem_p.cu b/tests/unit/test_fmha_hd64_smem_p.cu index ce0cdd7d..6ac8fa07 100644 --- a/tests/unit/test_fmha_hd64_smem_p.cu +++ b/tests/unit/test_fmha_hd64_smem_p.cu @@ -204,6 +204,22 @@ test_fmha_hd64_smem_p(const bf16_t* q, const bf16_t* k, const bf16_t* v, if (tid == 0) { printf("RegPV[0..7]: "); for(int d=0;d<8;d++) printf("%.6f ", s_ref_o[d]); printf("\n"); } __syncthreads(); + // Verify V SMEM values for kt=0 + if (tid == 0) { + bf16_t* sv0 = sV; // kt=0, nt=0 + // Read a few positions and compare with v[] + for (int d = 0; d < 4; d++) { + for (int r = 0; r < 4; r++) { + int g_mn = d / 8, g_k = r / 8; + int llr = d % 8, lc = r % 8; + int idx = g_k * 8 * 64 + g_mn * 64 + llr * 8 + lc; + printf("V_smem[%d,%d]=%.4f vs V_gmem=%.4f\n", d, r, + bf16_to_f32(sv0[idx]), bf16_to_f32(v[d * SK + r])); + } + } + } + __syncthreads(); + if (wid == 0) tmem_dealloc(tb, 128); }