diff --git a/tests/unit/test_p6_tma_store.cu b/tests/unit/test_p6_tma_store.cu index b97a7b45..bce8a4cf 100644 --- a/tests/unit/test_p6_tma_store.cu +++ b/tests/unit/test_p6_tma_store.cu @@ -34,7 +34,7 @@ int main() { constexpr int N = 128; constexpr int n_h = 4; constexpr int batch = 1; - constexpr float scale = 1.0f / sqrtf((float)HD); + const float scale = 1.0f / sqrtf((float)HD); // Allocate tensors bf16_t *d_q, *d_k, *d_v; @@ -151,10 +151,11 @@ int main() { float* f_direct = new float[n_h * HD]; float* f_tma = new float[n_h * HD]; - for (int i = 0; i < n_h * HD; i++) { - f_direct[i] = bf16_to_f32(h_o_direct[i]); - f_tma[i] = bf16_to_f32(h_o_tma[i]); - } + auto b2f = [](bf16_t h) -> float { float f; asm("cvt.f32.bf16 %0, %1;" : "=f"(f) : "h"(h)); return f; }; + for (int i = 0; i < n_h * HD; i++) { + f_direct[i] = b2f(h_o_direct[i]); + f_tma[i] = b2f(h_o_tma[i]); + } float cos = cosine_sim(f_direct, f_tma, n_h * HD); printf("P6 TMA epilogue test (hd=%d, n_h=%d): cos=%.8f\n", HD, n_h, cos);