Debug: print P values for HD=64

This commit is contained in:
2026-05-28 15:07:55 +00:00
parent 4b052f22a5
commit 6ea7356fdd

View File

@@ -126,6 +126,7 @@ test_fmha_hd64_smem_p(const bf16_t* q, const bf16_t* k, const bf16_t* v,
}
row_sum = wsum(row_sum);
if (lane == 0) for (int j=0;j<SK;j++) s_vals[j] /= row_sum;
if (lane == 0) { printf("P[0..7]: "); for(int j=0;j<8;j++) printf("%.6f ", s_vals[j]); printf("\n"); }
if (lane == 0) for (int j=0;j<SK;j++) s_p_vals[j] = s_vals[j];
}
__syncthreads();