debug: compare only first HD_CHUNK values
This commit is contained in:
@@ -180,8 +180,9 @@ static int test_single(int T, int s_k, int n_h = 1, int batch = 1) {
|
||||
o_ref, nullptr, HD, T, s_k, SCALE);
|
||||
|
||||
float cs = 0, na = 0, nb = 0;
|
||||
int check_hd = HD_CHUNK; // Only check first hd_chunk values (for partial debug)
|
||||
for (int t = 0; t < T; t++) {
|
||||
for (int d = 0; d < HD; d++) {
|
||||
for (int d = 0; d < check_hd; d++) {
|
||||
float a = bf16_to_f32_host(h_o[h * MAX_T * HD + t * HD + d]);
|
||||
float b = o_ref[t * HD + d];
|
||||
if (fabsf(b) > 1e-4f) { cs += a * b; na += a * a; nb += b * b; }
|
||||
|
||||
Reference in New Issue
Block a user