From 3da31de4c072f7256bf0d3477ca75e78078183e4 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 30 May 2026 09:04:05 +0000 Subject: [PATCH] P5: fix BF16 host helpers for standalone test --- tests/unit/test_p5_multitile.cu | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/tests/unit/test_p5_multitile.cu b/tests/unit/test_p5_multitile.cu index 3f567775..c10d4a2d 100644 --- a/tests/unit/test_p5_multitile.cu +++ b/tests/unit/test_p5_multitile.cu @@ -12,6 +12,15 @@ using namespace dsv4::kernels::attention; +// Host-side BF16 helpers +static float hbf16_to_f32(uint16_t h) { + uint32_t u = ((uint32_t)h) << 16; + float f; memcpy(&f, &u, 4); return f; +} +static uint16_t hf32_to_bf16(float f) { + uint32_t u; memcpy(&u, &f, 4); return (uint16_t)(u >> 16); +} + // CPU reference attention for single head void reference_attention( const bf16_t* q, const bf16_t* k, const bf16_t* v, @@ -23,7 +32,7 @@ void reference_attention( // Compute max for (int j = 0; j < s_k; j++) { float dot = 0.0f; - for (int d = 0; d < hd; d++) dot += bf16_to_f32(q[d]) * bf16_to_f32(k[j * hd + d]); + for (int d = 0; d < hd; d++) dot += hbf16_to_f32(q[d]) * hbf16_to_f32(k[j * hd + d]); dot *= scale; if (dot > row_max) row_max = dot; } @@ -31,11 +40,11 @@ void reference_attention( float row_sum = 0.0f; for (int j = 0; j < s_k; j++) { float dot = 0.0f; - for (int d = 0; d < hd; d++) dot += bf16_to_f32(q[d]) * bf16_to_f32(k[j * hd + d]); + for (int d = 0; d < hd; d++) dot += hbf16_to_f32(q[d]) * hbf16_to_f32(k[j * hd + d]); dot *= scale; float p = expf(dot - row_max); row_sum += p; - for (int d = 0; d < hd; d++) o_ref[d] += p * bf16_to_f32(v[d * s_k + j]); + for (int d = 0; d < hd; d++) o_ref[d] += p * hbf16_to_f32(v[d * s_k + j]); } for (int d = 0; d < hd; d++) o_ref[d] /= row_sum; *lse_ref = logf(row_sum) + row_max; @@ -51,9 +60,9 @@ int main() { float h_o_ref[HD], h_lse_ref; srand(42); - for (int d = 0; d < HD; d++) h_q[d] = f32_to_bf16((float)(rand() % 100) / 100.0f); - for (int j = 0; j < SK * HD; j++) h_k[j] = f32_to_bf16((float)(rand() % 100) / 100.0f); - for (int j = 0; j < HD * SK; j++) h_v[j] = f32_to_bf16((float)(rand() % 100) / 100.0f); + for (int d = 0; d < HD; d++) h_q[d] = hf32_to_bf16((float)(rand() % 100) / 100.0f); + for (int j = 0; j < SK * HD; j++) h_k[j] = hf32_to_bf16((float)(rand() % 100) / 100.0f); + for (int j = 0; j < HD * SK; j++) h_v[j] = hf32_to_bf16((float)(rand() % 100) / 100.0f); // CPU reference reference_attention(h_q, h_k, h_v, h_o_ref, &h_lse_ref, HD, SK, SCALE); @@ -124,7 +133,7 @@ int main() { float cos = 0, norm_a = 0, norm_b = 0; for (int d = 0; d < HD; d++) { float a = h_o_ref[d]; - float b = bf16_to_f32(h_o[d]); + float b = hbf16_to_f32(h_o[d]); cos += a * b; norm_a += a * a; norm_b += b * b; @@ -135,7 +144,7 @@ int main() { printf(" LSE: kernel=%.4f ref=%.4f\n", h_lse, h_lse_ref); printf(" Cosine similarity: %.6f\n", cos); printf(" Kernel O[0:5]:"); - for (int d = 0; d < 5; d++) printf(" %.4f", bf16_to_f32(h_o[d])); + for (int d = 0; d < 5; d++) printf(" %.4f", hbf16_to_f32(h_o[d])); printf("\n Ref O[0:5]:"); for (int d = 0; d < 5; d++) printf(" %.4f", h_o_ref[d]); printf("\n");