Debug: print P values for HD=64
This commit is contained in:
@@ -126,6 +126,7 @@ test_fmha_hd64_smem_p(const bf16_t* q, const bf16_t* k, const bf16_t* v,
|
||||
}
|
||||
row_sum = wsum(row_sum);
|
||||
if (lane == 0) for (int j=0;j<SK;j++) s_vals[j] /= row_sum;
|
||||
if (lane == 0) { printf("P[0..7]: "); for(int j=0;j<8;j++) printf("%.6f ", s_vals[j]); printf("\n"); }
|
||||
if (lane == 0) for (int j=0;j<SK;j++) s_p_vals[j] = s_vals[j];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
Reference in New Issue
Block a user