P5: fix BF16 host helpers for standalone test

This commit is contained in:
2026-05-30 09:04:05 +00:00
parent 9e6ba25a98
commit 3da31de4c0

View File

@@ -12,6 +12,15 @@
using namespace dsv4::kernels::attention;
// Host-side BF16 helpers
// Widen a bfloat16 bit pattern to a float on the host.
// The 16 input bits become the upper half of a 32-bit word (lower half
// zero), and that word is reinterpreted as a float through memcpy so no
// pointer-cast type punning is involved.
static float hbf16_to_f32(uint16_t h) {
    const uint32_t widened = (uint32_t)h << 16;
    float result;
    memcpy(&result, &widened, sizeof result);
    return result;
}
// Narrow a float to a bfloat16 bit pattern on the host using
// round-to-nearest-even.
//
// Plain truncation of the low 16 bits rounds toward zero and turns a NaN
// whose payload lives only in those low bits into an infinity. Here NaN
// inputs always yield a NaN bit pattern, and every other value is rounded to
// the nearest representable bfloat16, with ties going to the even mantissa.
static uint16_t hf32_to_bf16(float f) {
    uint32_t bits;
    memcpy(&bits, &f, sizeof bits);

    // NaN: exponent all ones with a non-zero mantissa. Keep sign and upper
    // payload, and force a mantissa bit that survives the narrowing so the
    // result cannot collapse to infinity.
    if ((bits & 0x7FFFFFFFu) > 0x7F800000u) {
        return (uint16_t)((bits >> 16) | 0x0040u);
    }

    // Add 0x7FFF plus the lowest kept bit: values above the halfway point
    // round up, and exact ties round up only when the kept mantissa is odd.
    // The sum cannot wrap because non-NaN inputs are at most 0xFF800000.
    const uint32_t lsb = (bits >> 16) & 1u;
    bits += 0x7FFFu + lsb;
    return (uint16_t)(bits >> 16);
}
// CPU reference attention for a single head
void reference_attention(
const bf16_t* q, const bf16_t* k, const bf16_t* v,
@@ -23,7 +32,7 @@ void reference_attention(
// Compute max
for (int j = 0; j < s_k; j++) {
float dot = 0.0f;
for (int d = 0; d < hd; d++) dot += bf16_to_f32(q[d]) * bf16_to_f32(k[j * hd + d]);
for (int d = 0; d < hd; d++) dot += hbf16_to_f32(q[d]) * hbf16_to_f32(k[j * hd + d]);
dot *= scale;
if (dot > row_max) row_max = dot;
}
@@ -31,11 +40,11 @@ void reference_attention(
float row_sum = 0.0f;
for (int j = 0; j < s_k; j++) {
float dot = 0.0f;
for (int d = 0; d < hd; d++) dot += bf16_to_f32(q[d]) * bf16_to_f32(k[j * hd + d]);
for (int d = 0; d < hd; d++) dot += hbf16_to_f32(q[d]) * hbf16_to_f32(k[j * hd + d]);
dot *= scale;
float p = expf(dot - row_max);
row_sum += p;
for (int d = 0; d < hd; d++) o_ref[d] += p * bf16_to_f32(v[d * s_k + j]);
for (int d = 0; d < hd; d++) o_ref[d] += p * hbf16_to_f32(v[d * s_k + j]);
}
for (int d = 0; d < hd; d++) o_ref[d] /= row_sum;
*lse_ref = logf(row_sum) + row_max;
@@ -51,9 +60,9 @@ int main() {
float h_o_ref[HD], h_lse_ref;
srand(42);
for (int d = 0; d < HD; d++) h_q[d] = f32_to_bf16((float)(rand() % 100) / 100.0f);
for (int j = 0; j < SK * HD; j++) h_k[j] = f32_to_bf16((float)(rand() % 100) / 100.0f);
for (int j = 0; j < HD * SK; j++) h_v[j] = f32_to_bf16((float)(rand() % 100) / 100.0f);
for (int d = 0; d < HD; d++) h_q[d] = hf32_to_bf16((float)(rand() % 100) / 100.0f);
for (int j = 0; j < SK * HD; j++) h_k[j] = hf32_to_bf16((float)(rand() % 100) / 100.0f);
for (int j = 0; j < HD * SK; j++) h_v[j] = hf32_to_bf16((float)(rand() % 100) / 100.0f);
// CPU reference
reference_attention(h_q, h_k, h_v, h_o_ref, &h_lse_ref, HD, SK, SCALE);
@@ -124,7 +133,7 @@ int main() {
float cos = 0, norm_a = 0, norm_b = 0;
for (int d = 0; d < HD; d++) {
float a = h_o_ref[d];
float b = bf16_to_f32(h_o[d]);
float b = hbf16_to_f32(h_o[d]);
cos += a * b;
norm_a += a * a;
norm_b += b * b;
@@ -135,7 +144,7 @@ int main() {
printf(" LSE: kernel=%.4f ref=%.4f\n", h_lse, h_lse_ref);
printf(" Cosine similarity: %.6f\n", cos);
printf(" Kernel O[0:5]:");
for (int d = 0; d < 5; d++) printf(" %.4f", bf16_to_f32(h_o[d]));
for (int d = 0; d < 5; d++) printf(" %.4f", hbf16_to_f32(h_o[d]));
printf("\n Ref O[0:5]:");
for (int d = 0; d < 5; d++) printf(" %.4f", h_o_ref[d]);
printf("\n");