diff --git a/tests/unit/test_umma_qk_hd64.cu b/tests/unit/test_umma_qk_hd64.cu index 121804dc..719027f3 100644 --- a/tests/unit/test_umma_qk_hd64.cu +++ b/tests/unit/test_umma_qk_hd64.cu @@ -225,6 +225,11 @@ int main() { cudaMemcpy(h_s_out, d_s_out, 128*16*sizeof(float), cudaMemcpyDeviceToHost); cudaMemcpy(h_s_scalar, d_s_scalar, SK*sizeof(float), cudaMemcpyDeviceToHost); + printf("Q[d=0..7] SMEM: "); for(int d=0;d<8;d++) printf("%.4f ", h_s_out[200+d]); printf("\n"); + printf("Q[d=16..23] SMEM: "); for(int d=0;d<8;d++) printf("%.4f ", h_s_out[208+d]); printf("\n"); + printf("Q[d=0..7] orig: "); for(int d=0;d<8;d++) printf("%.4f ", bf16_to_f32_host(h_q[d])); printf("\n"); + printf("Q[d=16..23] orig: "); for(int d=0;d<8;d++) printf("%.4f ", bf16_to_f32_host(h_q[16+d])); printf("\n"); + // Compare row 0 printf("S[0,0..7] MMA: "); for (int c = 0; c < 8; c++) printf("%.6f ", h_s_out[0*16+c]);