test: add Q SMEM verification output + bf16_to_f32_host

This commit is contained in:
2026-05-28 09:37:07 +00:00
parent 8f23c2aaf6
commit 7eb85a71fc

View File

@@ -19,6 +19,11 @@ static bf16_t f32_to_bf16_host(float f) {
return (uint16_t)(u >> 16);
}
// Host-side helper: expands a 16-bit bf16_t bit pattern into a float.
// The 16 input bits are shifted into the upper half of a 32-bit word (the
// lower 16 bits become zero) and that word's bytes are copied into a float,
// so the result is a bit-level reinterpretation, not a numeric conversion.
// This is the inverse of the `u >> 16` truncation in f32_to_bf16_host.
static float bf16_to_f32_host(bf16_t h) {
    const uint32_t widened = static_cast<uint32_t>(h) << 16;
    float value;
    // memcpy rather than a pointer cast: copies the raw bytes without
    // violating strict-aliasing rules.
    memcpy(&value, &widened, sizeof(value));
    return value;
}
__global__ void __launch_bounds__(NTHREADS)
test_umma_qk_hd16(
const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
@@ -187,7 +192,11 @@ int main() {
printf("desc_k=0x%016lx (addr=%lu,LBO=%lu,SBO=%lu)\n", dk, dk&0x3FFF, (dk>>16)&0x3FFF, (dk>>32)&0x3FFF);
printf("idesc=0x%08x, tmem_base=%.0f\n", idi, h_s_out[135]);
printf("Test %s\n", rel_err < 0.01f ? "PASSED" : "FAILED");
printf("Q from SMEM: ");
for (int d = 0; d < 16; d++) printf("%.4f ", h_s_out[160 + d]);
printf("\nQ original: ");
for (int d = 0; d < 16; d++) printf("%.4f ", bf16_to_f32_host(h_q[d]));
printf("\n");
cudaFree(d_q); cudaFree(d_k); cudaFree(d_s_out); cudaFree(d_s_scalar);
free(h_q); free(h_k); free(h_s_out); free(h_s_scalar);
return (rel_err < 0.01f) ? 0 : 1;