P5: fix BF16 host helpers for standalone test
This commit is contained in:
@@ -12,6 +12,15 @@
|
||||
|
||||
using namespace dsv4::kernels::attention;
|
||||
|
||||
// Host-side BF16 helpers
|
||||
// Widen a raw BF16 bit pattern to an IEEE-754 binary32 value on the host.
//
// The 16 input bits become the upper half of a 32-bit word and the lower
// half is zero-filled; that word is then reinterpreted as a float.
//
//   h        raw 16-bit BF16 pattern
//   returns  the float whose upper 16 bits equal h
static float hbf16_to_f32(uint16_t h) {
    const uint32_t widened = (uint32_t)h << 16;
    float result;
    // memcpy is the well-defined way to reinterpret the bits (no aliasing UB).
    memcpy(&result, &widened, sizeof result);
    return result;
}
|
||||
// Narrow an IEEE-754 binary32 value to a raw BF16 bit pattern on the host.
//
// Uses round-to-nearest, ties-to-even instead of plain truncation, so the
// result is not biased toward zero. NaN inputs stay NaN: the quiet bit is
// forced on, because dropping the low 16 mantissa bits of a NaN whose
// payload lives only in those bits would otherwise produce an infinity.
//
//   f        value to convert
//   returns  the 16-bit BF16 pattern nearest to f
static uint16_t hf32_to_bf16(float f) {
    uint32_t bits;
    // memcpy is the well-defined way to reinterpret the bits (no aliasing UB).
    memcpy(&bits, &f, sizeof bits);

    // NaN: exponent all ones and a non-zero mantissa.
    if ((bits & 0x7FFFFFFFu) > 0x7F800000u) {
        return (uint16_t)((bits >> 16) | 0x0040u);
    }

    // Adding 0x7FFF plus the lowest kept bit rounds halfway cases to the
    // even neighbour; a carry out of the mantissa correctly bumps the
    // exponent (the largest finite values round up to infinity).
    bits += 0x7FFFu + ((bits >> 16) & 1u);
    return (uint16_t)(bits >> 16);
}
|
||||
|
||||
// CPU reference attention for single head
|
||||
void reference_attention(
|
||||
const bf16_t* q, const bf16_t* k, const bf16_t* v,
|
||||
@@ -23,7 +32,7 @@ void reference_attention(
|
||||
// Compute max
|
||||
for (int j = 0; j < s_k; j++) {
|
||||
float dot = 0.0f;
|
||||
for (int d = 0; d < hd; d++) dot += bf16_to_f32(q[d]) * bf16_to_f32(k[j * hd + d]);
|
||||
for (int d = 0; d < hd; d++) dot += hbf16_to_f32(q[d]) * hbf16_to_f32(k[j * hd + d]);
|
||||
dot *= scale;
|
||||
if (dot > row_max) row_max = dot;
|
||||
}
|
||||
@@ -31,11 +40,11 @@ void reference_attention(
|
||||
float row_sum = 0.0f;
|
||||
for (int j = 0; j < s_k; j++) {
|
||||
float dot = 0.0f;
|
||||
for (int d = 0; d < hd; d++) dot += bf16_to_f32(q[d]) * bf16_to_f32(k[j * hd + d]);
|
||||
for (int d = 0; d < hd; d++) dot += hbf16_to_f32(q[d]) * hbf16_to_f32(k[j * hd + d]);
|
||||
dot *= scale;
|
||||
float p = expf(dot - row_max);
|
||||
row_sum += p;
|
||||
for (int d = 0; d < hd; d++) o_ref[d] += p * bf16_to_f32(v[d * s_k + j]);
|
||||
for (int d = 0; d < hd; d++) o_ref[d] += p * hbf16_to_f32(v[d * s_k + j]);
|
||||
}
|
||||
for (int d = 0; d < hd; d++) o_ref[d] /= row_sum;
|
||||
*lse_ref = logf(row_sum) + row_max;
|
||||
@@ -51,9 +60,9 @@ int main() {
|
||||
float h_o_ref[HD], h_lse_ref;
|
||||
|
||||
srand(42);
|
||||
for (int d = 0; d < HD; d++) h_q[d] = f32_to_bf16((float)(rand() % 100) / 100.0f);
|
||||
for (int j = 0; j < SK * HD; j++) h_k[j] = f32_to_bf16((float)(rand() % 100) / 100.0f);
|
||||
for (int j = 0; j < HD * SK; j++) h_v[j] = f32_to_bf16((float)(rand() % 100) / 100.0f);
|
||||
for (int d = 0; d < HD; d++) h_q[d] = hf32_to_bf16((float)(rand() % 100) / 100.0f);
|
||||
for (int j = 0; j < SK * HD; j++) h_k[j] = hf32_to_bf16((float)(rand() % 100) / 100.0f);
|
||||
for (int j = 0; j < HD * SK; j++) h_v[j] = hf32_to_bf16((float)(rand() % 100) / 100.0f);
|
||||
|
||||
// CPU reference
|
||||
reference_attention(h_q, h_k, h_v, h_o_ref, &h_lse_ref, HD, SK, SCALE);
|
||||
@@ -124,7 +133,7 @@ int main() {
|
||||
float cos = 0, norm_a = 0, norm_b = 0;
|
||||
for (int d = 0; d < HD; d++) {
|
||||
float a = h_o_ref[d];
|
||||
float b = bf16_to_f32(h_o[d]);
|
||||
float b = hbf16_to_f32(h_o[d]);
|
||||
cos += a * b;
|
||||
norm_a += a * a;
|
||||
norm_b += b * b;
|
||||
@@ -135,7 +144,7 @@ int main() {
|
||||
printf(" LSE: kernel=%.4f ref=%.4f\n", h_lse, h_lse_ref);
|
||||
printf(" Cosine similarity: %.6f\n", cos);
|
||||
printf(" Kernel O[0:5]:");
|
||||
for (int d = 0; d < 5; d++) printf(" %.4f", bf16_to_f32(h_o[d]));
|
||||
for (int d = 0; d < 5; d++) printf(" %.4f", hbf16_to_f32(h_o[d]));
|
||||
printf("\n Ref O[0:5]:");
|
||||
for (int d = 0; d < 5; d++) printf(" %.4f", h_o_ref[d]);
|
||||
printf("\n");
|
||||
|
||||
Reference in New Issue
Block a user