debug: compare only first HD_CHUNK values

This commit is contained in:
2026-05-30 06:59:39 +00:00
parent 9227b0e93f
commit 72779e7f71

View File

@@ -180,8 +180,9 @@ static int test_single(int T, int s_k, int n_h = 1, int batch = 1) {
o_ref, nullptr, HD, T, s_k, SCALE);
float cs = 0, na = 0, nb = 0;
int check_hd = HD_CHUNK; // Only check first hd_chunk values (for partial debug)
for (int t = 0; t < T; t++) {
for (int d = 0; d < HD; d++) {
for (int d = 0; d < check_hd; d++) {
float a = bf16_to_f32_host(h_o[h * MAX_T * HD + t * HD + d]);
float b = o_ref[t * HD + d];
if (fabsf(b) > 1e-4f) { cs += a * b; na += a * a; nb += b * b; }