test: add Q SMEM verification output + bf16_to_f32_host
This commit is contained in:
@@ -19,6 +19,11 @@ static bf16_t f32_to_bf16_host(float f) {
|
||||
return (uint16_t)(u >> 16);
|
||||
}
|
||||
|
||||
// Host-side helper: expand a bf16 bit pattern into a float.
// The input value is widened to 32 bits and shifted into the upper half of the
// word (the lower 16 bits are left as zero); the resulting 4 bytes are then
// copied verbatim into a float, so no numeric conversion or rounding happens.
// NOTE(review): presumably bf16_t is a 16-bit unsigned integer typedef -- confirm
// against its declaration, which is outside this view.
static float bf16_to_f32_host(bf16_t h) {
    const uint32_t bits = static_cast<uint32_t>(h) << 16;

    // Byte copy rather than a pointer cast, to reinterpret the word as a float.
    float value;
    memcpy(&value, &bits, sizeof(bits));
    return value;
}
|
||||
|
||||
__global__ void __launch_bounds__(NTHREADS)
|
||||
test_umma_qk_hd16(
|
||||
const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
|
||||
@@ -187,7 +192,11 @@ int main() {
|
||||
printf("desc_k=0x%016lx (addr=%lu,LBO=%lu,SBO=%lu)\n", dk, dk&0x3FFF, (dk>>16)&0x3FFF, (dk>>32)&0x3FFF);
|
||||
printf("idesc=0x%08x, tmem_base=%.0f\n", idi, h_s_out[135]);
|
||||
|
||||
printf("Test %s\n", rel_err < 0.01f ? "PASSED" : "FAILED");
|
||||
printf("Q from SMEM: ");
|
||||
for (int d = 0; d < 16; d++) printf("%.4f ", h_s_out[160 + d]);
|
||||
printf("\nQ original: ");
|
||||
for (int d = 0; d < 16; d++) printf("%.4f ", bf16_to_f32_host(h_q[d]));
|
||||
printf("\n");
|
||||
cudaFree(d_q); cudaFree(d_k); cudaFree(d_s_out); cudaFree(d_s_scalar);
|
||||
free(h_q); free(h_k); free(h_s_out); free(h_s_scalar);
|
||||
return (rel_err < 0.01f) ? 0 : 1;
|
||||
|
||||
Reference in New Issue
Block a user