From 7eb85a71fcad29fabedd9e52cb5d2409895c72f2 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 09:37:07 +0000 Subject: [PATCH] test: add Q SMEM verification output + bf16_to_f32_host --- tests/unit/test_umma_qk.cu | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_umma_qk.cu b/tests/unit/test_umma_qk.cu index 7884a279..a23b04b2 100644 --- a/tests/unit/test_umma_qk.cu +++ b/tests/unit/test_umma_qk.cu @@ -19,6 +19,11 @@ static bf16_t f32_to_bf16_host(float f) { return (uint16_t)(u >> 16); } +static float bf16_to_f32_host(bf16_t h) { + uint32_t u = (uint32_t)h << 16; + float f; memcpy(&f, &u, 4); return f; +} + __global__ void __launch_bounds__(NTHREADS) test_umma_qk_hd16( const bf16_t* __restrict__ q, const bf16_t* __restrict__ k, @@ -187,7 +192,11 @@ int main() { printf("desc_k=0x%016lx (addr=%lu,LBO=%lu,SBO=%lu)\n", dk, dk&0x3FFF, (dk>>16)&0x3FFF, (dk>>32)&0x3FFF); printf("idesc=0x%08x, tmem_base=%.0f\n", idi, h_s_out[135]); - printf("Test %s\n", rel_err < 0.01f ? "PASSED" : "FAILED"); + printf("Q from SMEM: "); + for (int d = 0; d < 16; d++) printf("%.4f ", h_s_out[160 + d]); + printf("\nQ original: "); + for (int d = 0; d < 16; d++) printf("%.4f ", bf16_to_f32_host(h_q[d])); + printf("\n"); cudaFree(d_q); cudaFree(d_k); cudaFree(d_s_out); cudaFree(d_s_scalar); free(h_q); free(h_k); free(h_s_out); free(h_s_scalar); return (rel_err < 0.01f) ? 0 : 1;